Commit 81d80f7f authored by sbl1996@126.com

Fix `done` handling for rstate

parent 3bf0bc91
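In brief: the commit moves the `done`-masked reset of the recurrent state (`rstate`) out of the per-model forward helpers and applies it once, after the per-player states have been merged back, so a reset cannot be lost when states are swapped between the main and opponent agents. A minimal sketch of the resulting pattern (assumptions: `apply_fn` is any `(params, (rstate, obs)) -> (next_rstate, logits, ...)` forward function, `done` is a boolean `[batch]` mask; names are illustrative, not from the repo):

```python
import jax
import jax.numpy as jnp

def step(apply_fn, params, rstate, obs, done):
    # Forward pass: advance the recurrent state, no reset logic inside.
    next_rstate, logits = apply_fn(params, (rstate, obs))[:2]
    # Reset exactly once, at the end of the step, for terminated envs;
    # done[:, None] broadcasts the [batch] mask over the feature axis.
    next_rstate = jax.tree.map(
        lambda x: jnp.where(done[:, None], 0, x), next_rstate)
    return next_rstate, logits
```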
@@ -157,24 +157,26 @@ if __name__ == "__main__":
     params2 = jax.device_put(params2)
     @jax.jit
-    def get_probs(params, rstate, obs, done):
+    def get_probs(params, rstate, obs, done=None):
         agent = create_agent(args)
         next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
         probs = jax.nn.softmax(logits, axis=-1)
-        next_rstate = jax.tree.map(
-            lambda x: jnp.where(done[:, None], 0, x), next_rstate)
+        if done is not None:
+            next_rstate = jnp.where(done[:, None], 0, next_rstate)
         return next_rstate, probs
     if args.num_envs != 1:
         @jax.jit
         def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
-            next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
-            next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
+            next_rstate1, probs1 = get_probs(params1, rstate1, obs)
+            next_rstate2, probs2 = get_probs(params2, rstate2, obs)
             probs = jnp.where(main[:, None], probs1, probs2)
             rstate1 = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
             rstate2 = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+            rstate1, rstate2 = jax.tree.map(
+                lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
             return rstate1, rstate2, probs
     def predict_fn(rstate1, rstate2, obs, main, done):
......
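A hedged usage sketch of the new `get_probs` signature above: the battle path (`get_probs2`) omits `done` and performs one combined reset after merging both players' states, while a single-model caller still passes `done` so the reset happens inside. The stub below is illustrative only (identity "network" with a single `[batch, hidden]` state array), not the repo's agent:

```python
import jax
import jax.numpy as jnp

# Stand-in for the eval script's get_probs; placeholder forward pass.
def get_probs(params, rstate, obs, done=None):
    next_rstate, logits = rstate + obs, obs
    if done is not None:
        next_rstate = jnp.where(done[:, None], 0, next_rstate)
    return next_rstate, jax.nn.softmax(logits, axis=-1)

rstate = jnp.zeros((4, 8))
obs = jnp.ones((4, 8))
done = jnp.array([True, False, False, False])

# Battle path: omit `done`; the caller resets after merging states.
rs, _ = get_probs(None, rstate, obs)
# Single-model path: pass `done`; rows of terminated envs are zeroed.
rs_reset, _ = get_probs(None, rstate, obs, done)
assert bool((rs_reset[0] == 0).all())
```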
 import os
+import shutil
 import queue
 import random
 import threading
@@ -45,6 +46,8 @@ class Args:
     """the frequency of saving the model (in terms of `updates`)"""
     checkpoint: Optional[str] = None
     """the path to the model checkpoint to load"""
+    debug: bool = False
+    """whether to run the script in debug mode"""
     tb_dir: str = "runs"
     """the directory to save the tensorboard logs"""
@@ -156,7 +159,7 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
-    freeze_id: bool = False
+    freeze_id: Optional[bool] = None
 def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
@@ -253,28 +256,27 @@ def rollout(
     @jax.jit
     def get_logits(
-            params: flax.core.FrozenDict, inputs, done):
+            params: flax.core.FrozenDict, inputs):
         rstate, logits = create_agent(args).apply(params, inputs)[:2]
-        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
         return rstate, logits
     @jax.jit
     def get_action(
             params: flax.core.FrozenDict, inputs):
-        batch_size = jax.tree.leaves(inputs)[0].shape[0]
-        done = jnp.zeros(batch_size, dtype=jnp.bool_)
-        rstate, logits = get_logits(params, inputs, done)
+        rstate, logits = get_logits(params, inputs)
         return rstate, logits.argmax(axis=1)
     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
         rstate2 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         return rstate1, rstate2, logits.argmax(axis=1)
     @jax.jit
@@ -284,12 +286,14 @@ def rollout(
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
         done = jnp.array(done)
         main = jnp.array(main)
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs), done)
+        rstate, logits = get_logits(params, (rstate, next_obs))
         rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
         rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         action, key = categorical_sample(logits, key)
         return next_obs, done, main, rstate1, rstate2, action, logits, key
@@ -517,13 +521,18 @@ if __name__ == "__main__":
     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
-    tb_log_dir = f"{args.tb_dir}/{run_name}"
-    writer = SummaryWriter(tb_log_dir)
-    writer.add_text(
-        "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
-    )
+    dummy_writer = SimpleNamespace()
+    dummy_writer.add_scalar = lambda x, y, z: None
+    tb_log_dir = f"{args.tb_dir}/{run_name}"
+    if args.local_rank == 0 and not args.debug:
+        writer = SummaryWriter(tb_log_dir)
+        writer.add_text(
+            "hyperparameters",
+            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        )
+    else:
+        writer = dummy_writer
     def save_fn(obj, path):
         with open(path, "wb") as f:
@@ -669,6 +678,7 @@ if __name__ == "__main__":
         learn_opponent: bool = False,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
+        # TODO: rstate will be out of date after the first update; maybe consider R2D2
         next_inputs, init_rstate1, init_rstate2 = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
@@ -756,8 +766,6 @@ if __name__ == "__main__":
     params_queues = []
     rollout_queues = []
    eval_queue = queue.Queue()
-    dummy_writer = SimpleNamespace()
-    dummy_writer.add_scalar = lambda x, y, z: None
     unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
     for d_idx, d_id in enumerate(args.actor_device_ids):
@@ -844,13 +852,16 @@ if __name__ == "__main__":
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
            writer.add_scalar("losses/loss", loss, global_step)
-        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
+        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
             M_steps = args.batch_size * learner_policy_version // (2**20)
             ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
             ckpt_maneger.save(unreplicated_params, ckpt_name)
             if args.gcs_bucket is not None:
+                lastest_path = ckpt_maneger.get_latest()
+                copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
+                shutil.copyfile(lastest_path, copy_path)
                 zip_file_path = "latest.zip"
-                zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
+                zip_files(zip_file_path, [str(copy_path), tb_log_dir])
                 sync_to_gcs(args.gcs_bucket, zip_file_path)
         if learner_policy_version >= args.num_updates:
......
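The `dummy_writer` wiring added above is a small null-object pattern: non-zero ranks and `--debug` runs get an object with the same `add_scalar` call shape as `SummaryWriter`, so the logging sites need no rank checks. A self-contained sketch of the idea (assuming the usual `torch.utils.tensorboard.SummaryWriter`; `make_writer` is a name invented here, and the real script also calls `add_text` inside the rank-0 branch only):

```python
from types import SimpleNamespace

def make_writer(is_main: bool, log_dir: str):
    """Return a real SummaryWriter on the main process, else a no-op stub."""
    if is_main:
        from torch.utils.tensorboard import SummaryWriter
        return SummaryWriter(log_dir)
    # Null object: same add_scalar signature, silently drops everything.
    stub = SimpleNamespace()
    stub.add_scalar = lambda tag, value, step: None
    return stub

writer = make_writer(is_main=False, log_dir="runs/example")
writer.add_scalar("losses/loss", 0.5, 100)  # no-op off the main process
```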
@@ -47,6 +47,8 @@ class Args:
     """the frequency of saving the model (in terms of `updates`)"""
     checkpoint: Optional[str] = None
     """the path to the model checkpoint to load"""
+    debug: bool = False
+    """whether to run the script in debug mode"""
     tb_dir: str = "runs"
     """the directory to save the tensorboard logs"""
@@ -156,7 +158,7 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
-    freeze_id: bool = False
+    freeze_id: Optional[bool] = None
 def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
@@ -254,28 +256,27 @@ def rollout(
     @jax.jit
     def get_logits(
-            params: flax.core.FrozenDict, inputs, done):
+            params: flax.core.FrozenDict, inputs):
         rstate, logits = create_agent(args).apply(params, inputs)[:2]
-        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
         return rstate, logits
     @jax.jit
     def get_action(
             params: flax.core.FrozenDict, inputs):
-        batch_size = jax.tree.leaves(inputs)[0].shape[0]
-        done = jnp.zeros(batch_size, dtype=jnp.bool_)
-        rstate, logits = get_logits(params, inputs, done)
+        rstate, logits = get_logits(params, inputs)
         return rstate, logits.argmax(axis=1)
     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
         rstate2 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         return rstate1, rstate2, logits.argmax(axis=1)
     @jax.jit
@@ -285,12 +286,14 @@ def rollout(
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
         done = jnp.array(done)
         main = jnp.array(main)
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs), done)
+        rstate, logits = get_logits(params, (rstate, next_obs))
         rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
         rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         action, key = categorical_sample(logits, key)
         return next_obs, done, main, rstate1, rstate2, action, logits, key
@@ -532,7 +535,7 @@ if __name__ == "__main__":
     dummy_writer.add_scalar = lambda x, y, z: None
     tb_log_dir = f"{args.tb_dir}/{run_name}"
-    if args.local_rank == 0:
+    if args.local_rank == 0 and not args.debug:
         writer = SummaryWriter(tb_log_dir)
         writer.add_text(
             "hyperparameters",
@@ -692,6 +695,7 @@ if __name__ == "__main__":
         learn_opponent: bool = False,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
+        # TODO: rstate will be out of date after the first update; maybe consider R2D2
         next_inputs, init_rstate1, init_rstate2 = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
@@ -874,7 +878,7 @@ if __name__ == "__main__":
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
            writer.add_scalar("losses/loss", loss, global_step)
-        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
+        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
             M_steps = args.batch_size * learner_policy_version // (2**20)
             ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
             ckpt_maneger.save(unreplicated_params, ckpt_name)
......
@@ -169,8 +169,6 @@ class Encoder(nn.Module):
         fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
         id_embed = embed(n_embed, embed_dim)
-        if self.freeze_id:
-            id_embed = lambda x: jax.lax.stop_gradient(id_embed(x))
         action_encoder = ActionEncoder(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
@@ -184,6 +182,8 @@ class Encoder(nn.Module):
         x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
         x_id = id_embed(x_id)
+        if self.freeze_id:
+            x_id = jax.lax.stop_gradient(x_id)
         # Cards
         f_cards = CardEncoder(
@@ -215,9 +215,12 @@ class Encoder(nn.Module):
         h_mask = h_mask.at[:, 0].set(False)
         x_h_id = decode_id(x_h_actions[..., :2])
+        x_h_id = id_embed(x_h_id)
+        if self.freeze_id:
+            x_h_id = jax.lax.stop_gradient(x_h_id)
         x_h_id = MLP(
             (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
-            kernel_init=default_fc_init2)(id_embed(x_h_id))
+            kernel_init=default_fc_init2)(x_h_id)
         x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
@@ -379,9 +382,9 @@ class PPOLSTMAgent(nn.Module):
             rstate1, rstate2 = carry
             rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
             rstate, y = cell(rstate, x)
-            rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
             rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
             rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+            rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
             return (rstate1, rstate2), y
         scan = nn.scan(
             body_fn, variable_broadcast='params',
......
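The `freeze_id` change above also fixes a real bug: the removed line `id_embed = lambda x: jax.lax.stop_gradient(id_embed(x))` rebinds the name `id_embed` to a lambda that calls itself, so invoking it recurses forever instead of wrapping the embedding. The new code embeds first and detaches the output. A minimal self-contained sketch of the corrected pattern (module and field names here are illustrative, not the repo's):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class FrozenIdEmbed(nn.Module):
    n_embed: int = 999
    embed_dim: int = 64
    freeze_id: bool = False

    @nn.compact
    def __call__(self, ids):
        x = nn.Embed(self.n_embed, self.embed_dim)(ids)
        if self.freeze_id:
            # Detach the *output*: the embedding table gets no gradient,
            # without rebinding the embed callable to itself.
            x = jax.lax.stop_gradient(x)
        return x

module = FrozenIdEmbed(freeze_id=True)
params = module.init(jax.random.PRNGKey(0), jnp.zeros((2, 3), jnp.int32))
```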
@@ -48,7 +48,7 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
     return deck_name
-def load_embeddings(embedding_file, code_list_file):
+def load_embeddings(embedding_file, code_list_file, pad_to=999):
     with open(embedding_file, "rb") as f:
         embeddings = pickle.load(f)
     with open(code_list_file, "r") as f:
@@ -56,4 +56,8 @@ def load_embeddings(embedding_file, code_list_file):
         code_list = [int(code.strip()) for code in code_list]
     assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
     embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
+    if pad_to is not None:
+        assert pad_to >= len(embeddings), f"pad_to={pad_to} < len(embeddings)={len(embeddings)}"
+        pad = np.zeros((pad_to - len(embeddings), embeddings.shape[1]), dtype=np.float32)
+        embeddings = np.concatenate([embeddings, pad], axis=0)
     return embeddings
\ No newline at end of file
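Finally, the new `pad_to=999` parameter zero-pads the `[n, d]` embedding matrix to a fixed row count, presumably so it matches a fixed `num_embeddings` in the model. A standalone restatement with a tiny usage check (the 999 default mirrors the diff; the function name and other values here are illustrative):

```python
import numpy as np

def pad_embeddings(embeddings: np.ndarray, pad_to: int = 999) -> np.ndarray:
    """Zero-pad an [n, d] embedding matrix to [pad_to, d] rows."""
    n, d = embeddings.shape
    assert pad_to >= n, f"pad_to={pad_to} < len(embeddings)={n}"
    pad = np.zeros((pad_to - n, d), dtype=embeddings.dtype)
    return np.concatenate([embeddings, pad], axis=0)

emb = np.random.randn(10, 4).astype(np.float32)
padded = pad_embeddings(emb, pad_to=16)
assert padded.shape == (16, 4) and (padded[10:] == 0).all()
```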