Commit 45885f81 authored by Biluo Shen

Add apply_if_finite in opt

parent 7f80e9a8
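For context on the commit title: optax.apply_if_finite wraps an optimizer so that updates whose gradients contain NaN/Inf are dropped instead of applied, up to max_consecutive_errors in a row, after which they are let through. A minimal standalone sketch (a toy Adam setup, not this training script):

```python
import jax.numpy as jnp
import optax

# Wrap any optimizer so non-finite updates are skipped instead of applied.
tx = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=3)

params = {"w": jnp.ones(3)}
opt_state = tx.init(params)

bad_grads = {"w": jnp.array([jnp.nan, 0.0, 0.0])}
updates, opt_state = tx.update(bad_grads, opt_state, params)
params = optax.apply_updates(params, updates)  # zero update: params unchanged
print(opt_state.notfinite_count)  # 1 consecutive non-finite update seen
```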
@@ -23,6 +23,7 @@ from rich.pretty import pprint
 from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
+from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
@@ -45,6 +46,13 @@ class Args:
     checkpoint: Optional[str] = None
     """the path to the model checkpoint to load"""
+    tb_dir: str = "runs"
+    """the directory to save the tensorboard logs"""
+    ckpt_dir: str = "checkpoints"
+    """the directory to save the model checkpoints"""
+    gcs_bucket: Optional[str] = None
+    """the GCS bucket to save the model checkpoints"""
+
 
     # Algorithm specific arguments
     env_id: str = "YGOPro-v0"
     """the id of the environment"""
@@ -151,7 +159,7 @@ class Args:
     freeze_id: bool = False
 
 
-def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
+def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
     if not args.thread_affinity:
         thread_affinity_offset = -1
     if thread_affinity_offset >= 0:
@@ -168,7 +176,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
         async_reset=False,
-        greedy_reward=args.greedy_reward if mode == 'self' else True,
+        greedy_reward=args.greedy_reward if not eval else True,
         play_mode=mode,
     )
     envs.num_envs = num_envs
@@ -231,7 +239,7 @@ def rollout(
         args,
         args.seed + jax.process_index() + device_thread_id,
         args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode=eval_mode)
+        args.local_eval_episodes // 4, mode=eval_mode, eval=True)
     eval_envs = RecordEpisodeStatistics(eval_envs)
 
     len_actor_device_ids = len(args.actor_device_ids)
@@ -431,14 +439,13 @@
             _start = time.time()
             if eval_mode == 'bot':
                 predict_fn = lambda x: get_action(params, x)
-                eval_stat = evaluate(
-                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
-                metric_name = "eval_return"
+                eval_return, eval_ep_len, eval_win_rate = evaluate(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
             else:
                 predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
-                eval_stat = battle(
-                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
-                metric_name = "eval_win_rate"
+                eval_return, eval_ep_len, eval_win_rate = battle(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+            eval_stat = np.array([eval_return, eval_win_rate])
             if device_thread_id != 0:
                 eval_queue.put(eval_stat)
             else:
@@ -446,12 +453,14 @@
                 eval_stats.append(eval_stat)
                 for _ in range(1, n_actors):
                     eval_stats.append(eval_queue.get())
-                eval_stats = np.mean(eval_stats)
-                writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
+                eval_stats = np.stack(eval_stats)
+                eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
+                writer.add_scalar(f"charts/eval_return", eval_return, global_step)
+                writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
             if device_thread_id == 0:
                 eval_time = time.time() - _start
-                print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                 other_time += eval_time
+                print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
if __name__ == "__main__":
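The hunk above replaces the single mode-dependent scalar (eval_return or eval_win_rate, chosen via metric_name) with a stacked pair, so both metrics are logged on every evaluation regardless of mode. A small illustration of the cross-actor aggregation, with made-up values:

```python
import numpy as np

# Each actor thread reports a (return, win_rate) pair via eval_queue;
# thread 0 stacks them and averages per metric instead of one global mean.
per_actor = [np.array([0.12, 0.55]), np.array([0.08, 0.51]), np.array([0.10, 0.53])]
eval_return, eval_win_rate = np.mean(np.stack(per_actor), axis=0)
print(f"eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
```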
@@ -508,12 +517,21 @@ if __name__ == "__main__":
     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
-    writer = SummaryWriter(f"runs/{run_name}")
+    tb_log_dir = f"{args.tb_dir}/{run_name}"
+    writer = SummaryWriter(tb_log_dir)
     writer.add_text(
         "hyperparameters",
         "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
 
+    def save_fn(obj, path):
+        with open(path, "wb") as f:
+            f.write(flax.serialization.to_bytes(obj))
+
+    ckpt_manager = ModelCheckpoint(
+        args.ckpt_dir, save_fn, n_saved=3)
+
     # seeding
     random.seed(args.seed)
     np.random.seed(args.seed)
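The new save_fn relies on flax's byte-level serialization; the checkpoint-loading branch further down is its inverse. A minimal round trip of that format, assuming only flax and jax (names here are illustrative):

```python
import flax.serialization
import jax.numpy as jnp

def save_params(params, path):
    # Serialize the parameter pytree to its msgpack byte representation.
    with open(path, "wb") as f:
        f.write(flax.serialization.to_bytes(params))

def load_params(template, path):
    # from_bytes needs a pytree of the same structure to restore into.
    with open(path, "rb") as f:
        return flax.serialization.from_bytes(template, f.read())

params = {"dense": {"kernel": jnp.ones((2, 2))}}
save_params(params, "/tmp/params.flax_model")
restored = load_params(params, "/tmp/params.flax_model")
```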
@@ -559,12 +577,12 @@ if __name__ == "__main__":
         ),
         every_k_schedule=1,
     )
-
+    tx = optax.apply_if_finite(tx, max_consecutive_errors=3)
     agent_state = TrainState.create(
         apply_fn=None,
         params=params,
         tx=tx,
     )
     if args.checkpoint:
         with open(args.checkpoint, "rb") as f:
             params = flax.serialization.from_bytes(params, f.read())
@@ -589,7 +607,7 @@ if __name__ == "__main__":
             args, multi_step=True).apply(params, inputs)
         return logits, value.squeeze(-1)
 
-    def ppo_loss(
+    def loss_fn(
             params, rstate1, rstate2, obs, dones, mains,
             actions, logits, rewards, mask, next_value, next_done):
         # (num_steps * local_num_envs // n_mb))
@@ -663,7 +681,7 @@ if __name__ == "__main__":
         # main first, opponent second
         num_steps, num_envs = storage.rewards.shape
 
-        ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
+        loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
 
         def update_epoch(carry, _):
             agent_state, key = carry
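The rename from ppo_loss to loss_fn is cosmetic; the has_aux=True pattern is unchanged: the loss function returns (loss, aux) and jax.value_and_grad threads the auxiliary diagnostics through alongside the gradients. A toy example of the same pattern:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = params["w"] * x
    loss = jnp.mean((pred - y) ** 2)
    aux = (jnp.abs(pred - y).mean(),)  # extra diagnostics, not differentiated
    return loss, aux

loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, aux), grads = loss_grad_fn({"w": jnp.float32(2.0)}, jnp.ones(4), jnp.zeros(4))
```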
@@ -693,7 +711,7 @@ if __name__ == "__main__":
             shuffled_mask = jnp.ones_like(shuffled_storage.mains)
 
             def update_minibatch(agent_state, minibatch):
-                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
+                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
                     agent_state.params, *minibatch)
                 grads = jax.lax.pmean(grads, axis_name="local_devices")
                 agent_state = agent_state.apply_gradients(grads=grads)
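jax.lax.pmean in update_minibatch averages the gradients across the local devices participating in the pmap'd update, so every replica applies the same step. A self-contained sketch of that pattern (the axis name matches the one above; the data is made up):

```python
import jax
import jax.numpy as jnp

def average_across_devices(g):
    # Inside pmap, pmean reduces each leaf over the named axis.
    return jax.lax.pmean(g, axis_name="local_devices")

n = jax.local_device_count()
per_device = jnp.arange(float(n)).reshape(n, 1)  # one row per device
print(jax.pmap(average_across_devices, axis_name="local_devices")(per_device))
```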
@@ -818,7 +836,7 @@ if __name__ == "__main__":
            f"data_time={rollout_queue_get_time[-1]:.2f}"
         )
         writer.add_scalar(
-            "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
+            "charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), global_step
         )
         writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
         writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
@@ -827,15 +845,13 @@ if __name__ == "__main__":
         writer.add_scalar("losses/loss", loss, global_step)
 
         if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
-            ckpt_dir = f"checkpoints"
-            os.makedirs(ckpt_dir, exist_ok=True)
             M_steps = args.batch_size * learner_policy_version // (2**20)
-            model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
-            with open(model_path, "wb") as f:
-                f.write(
-                    flax.serialization.to_bytes(unreplicated_params)
-                )
-            print(f"model saved to {model_path}")
+            ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
+            ckpt_manager.save(unreplicated_params, ckpt_name)
+            if args.gcs_bucket is not None:
+                zip_file_path = "latest.zip"
+                zip_files(zip_file_path, [ckpt_manager.get_latest(), tb_log_dir])
+                sync_to_gcs(args.gcs_bucket, zip_file_path)
 
         if learner_policy_version >= args.num_updates:
             break
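zip_files and sync_to_gcs come from ygoai.rl.ckpt, and only their call shapes above are known from this diff. A hedged sketch of what a zip_files like this plausibly does (the real implementation may differ):

```python
import os
import zipfile

def zip_files(zip_path, paths):
    # Bundle files and directories into a single archive for upload.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for p in paths:
            if os.path.isdir(p):
                for root, _, names in os.walk(p):
                    for name in names:
                        full = os.path.join(root, name)
                        zf.write(full, os.path.relpath(full))
            else:
                zf.write(p, os.path.basename(p))
```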
@@ -587,12 +587,12 @@ if __name__ == "__main__":
         ),
         every_k_schedule=1,
     )
-
+    tx = optax.apply_if_finite(tx, max_consecutive_errors=3)
     agent_state = TrainState.create(
         apply_fn=None,
         params=params,
         tx=tx,
     )
     if args.checkpoint:
         with open(args.checkpoint, "rb") as f:
             params = flax.serialization.from_bytes(params, f.read())
@@ -862,7 +862,7 @@ if __name__ == "__main__":
            f"data_time={rollout_queue_get_time[-1]:.2f}"
         )
         writer.add_scalar(
-            "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
+            "charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), global_step
         )
         writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
         writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)