Commit 45885f81 authored by Biluo Shen

Add apply_if_finite in opt

parent 7f80e9a8
@@ -23,6 +23,7 @@ from rich.pretty import pprint
 from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
+from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
@@ -45,6 +46,13 @@ class Args:
 checkpoint: Optional[str] = None
 """the path to the model checkpoint to load"""
+tb_dir: str = "runs"
+"""the directory to save the tensorboard logs"""
+ckpt_dir: str = "checkpoints"
+"""the directory to save the model checkpoints"""
+gcs_bucket: Optional[str] = None
+"""the GCS bucket to save the model checkpoints"""
 # Algorithm specific arguments
 env_id: str = "YGOPro-v0"
 """the id of the environment"""
@@ -151,7 +159,7 @@ class Args:
 freeze_id: bool = False
-def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
+def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
 if not args.thread_affinity:
 thread_affinity_offset = -1
 if thread_affinity_offset >= 0:
@@ -168,7 +176,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
 max_options=args.max_options,
 n_history_actions=args.n_history_actions,
 async_reset=False,
-greedy_reward=args.greedy_reward if mode == 'self' else True,
+greedy_reward=args.greedy_reward if not eval else True,
 play_mode=mode,
 )
 envs.num_envs = num_envs
@@ -231,7 +239,7 @@ def rollout(
 args,
 args.seed + jax.process_index() + device_thread_id,
 args.local_eval_episodes,
-args.local_eval_episodes // 4, mode=eval_mode)
+args.local_eval_episodes // 4, mode=eval_mode, eval=True)
 eval_envs = RecordEpisodeStatistics(eval_envs)
 len_actor_device_ids = len(args.actor_device_ids)
@@ -431,14 +439,13 @@ def rollout(
 _start = time.time()
 if eval_mode == 'bot':
 predict_fn = lambda x: get_action(params, x)
-eval_stat = evaluate(
-eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
-metric_name = "eval_return"
+eval_return, eval_ep_len, eval_win_rate = evaluate(
+eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
 else:
 predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
-eval_stat = battle(
-eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
-metric_name = "eval_win_rate"
+eval_return, eval_ep_len, eval_win_rate = battle(
+eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+eval_stat = np.array([eval_return, eval_win_rate])
 if device_thread_id != 0:
 eval_queue.put(eval_stat)
 else:
@@ -446,12 +453,14 @@ def rollout(
 eval_stats.append(eval_stat)
 for _ in range(1, n_actors):
 eval_stats.append(eval_queue.get())
-eval_stats = np.mean(eval_stats)
-writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
+eval_stats = np.stack(eval_stats)
+eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
+writer.add_scalar(f"charts/eval_return", eval_return, global_step)
+writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
 if device_thread_id == 0:
 eval_time = time.time() - _start
-print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
 other_time += eval_time
+print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")

 if __name__ == "__main__":
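Note: judging from the unpacking above, evaluate and battle are now both expected to return an (eval_return, eval_ep_len, eval_win_rate) tuple, and every actor thread contributes one [eval_return, eval_win_rate] pair that thread 0 averages. A minimal standalone sketch of that aggregation, with made-up numbers:

    import numpy as np

    # Hypothetical per-actor eval results; each entry is [eval_return, eval_win_rate].
    per_actor_stats = [
        np.array([0.12, 0.55]),  # device thread 0 (local)
        np.array([0.08, 0.52]),  # device thread 1 (received via eval_queue)
        np.array([0.15, 0.58]),  # device thread 2 (received via eval_queue)
    ]

    # Same reduction as in the diff: stack to shape (n_actors, 2), then mean over actors.
    eval_stats = np.stack(per_actor_stats)
    eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
    print(f"eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")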
@@ -508,12 +517,21 @@ if __name__ == "__main__":
 timestamp = int(time.time())
 run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
-writer = SummaryWriter(f"runs/{run_name}")
+tb_log_dir = f"{args.tb_dir}/{run_name}"
+writer = SummaryWriter(tb_log_dir)
 writer.add_text(
 "hyperparameters",
 "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
 )
+def save_fn(obj, path):
+with open(path, "wb") as f:
+f.write(flax.serialization.to_bytes(obj))
+ckpt_maneger = ModelCheckpoint(
+args.ckpt_dir, save_fn, n_saved=3)
 # seeding
 random.seed(args.seed)
 np.random.seed(args.seed)
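Note: ModelCheckpoint from ygoai.rl.ckpt is not part of this diff. From the way it is used (a directory, a save function and n_saved in the constructor, plus save(obj, name) and get_latest() further down), it is presumably a small rotating-checkpoint manager. A rough sketch of such a class under that assumption, purely to illustrate the expected interface:

    import os

    class ModelCheckpoint:
        """Keep only the newest n_saved checkpoints in dirname (assumed behaviour)."""

        def __init__(self, dirname, save_fn, n_saved=3):
            self.dirname = dirname
            self.save_fn = save_fn   # e.g. the flax save_fn defined above
            self.n_saved = n_saved
            self.saved = []          # saved paths, oldest first
            os.makedirs(dirname, exist_ok=True)

        def save(self, obj, name):
            path = os.path.join(self.dirname, name)
            self.save_fn(obj, path)
            self.saved.append(path)
            # Drop the oldest checkpoints once the budget is exceeded.
            while len(self.saved) > self.n_saved:
                old = self.saved.pop(0)
                if os.path.exists(old):
                    os.remove(old)
            return path

        def get_latest(self):
            return self.saved[-1] if self.saved else None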
@@ -559,12 +577,12 @@ if __name__ == "__main__":
 ),
 every_k_schedule=1,
 )
+tx = optax.apply_if_finite(tx, max_consecutive_errors=3)
 agent_state = TrainState.create(
 apply_fn=None,
 params=params,
 tx=tx,
 )
 if args.checkpoint:
 with open(args.checkpoint, "rb") as f:
 params = flax.serialization.from_bytes(params, f.read())
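Note: optax.apply_if_finite wraps the whole optimizer so that an update whose gradients contain NaN or Inf is replaced by a zero update, leaving the params and the inner optimizer state untouched; only after more than max_consecutive_errors consecutive non-finite updates does it give up and let the update through. A small standalone sketch of the behaviour (toy optimizer and shapes, not the script's actual configuration):

    import jax.numpy as jnp
    import optax

    params = {"w": jnp.ones((3,))}
    tx = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=3)
    opt_state = tx.init(params)

    # A gradient containing NaN is rejected: the produced update is all zeros.
    bad_grads = {"w": jnp.array([jnp.nan, 1.0, 2.0])}
    updates, opt_state = tx.update(bad_grads, opt_state, params)
    params = optax.apply_updates(params, updates)  # params unchanged

    # ApplyIfFiniteState exposes counters that can be logged for monitoring.
    print(opt_state.notfinite_count, opt_state.total_notfinite)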
@@ -589,7 +607,7 @@ if __name__ == "__main__":
 args, multi_step=True).apply(params, inputs)
 return logits, value.squeeze(-1)
-def ppo_loss(
+def loss_fn(
 params, rstate1, rstate2, obs, dones, mains,
 actions, logits, rewards, mask, next_value, next_done):
 # (num_steps * local_num_envs // n_mb))
@@ -663,7 +681,7 @@ if __name__ == "__main__":
 # main first, opponent second
 num_steps, num_envs = storage.rewards.shape
-ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
+loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
 def update_epoch(carry, _):
 agent_state, key = carry
@@ -693,7 +711,7 @@ if __name__ == "__main__":
 shuffled_mask = jnp.ones_like(shuffled_storage.mains)
 def update_minibatch(agent_state, minibatch):
-(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
+(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
 agent_state.params, *minibatch)
 grads = jax.lax.pmean(grads, axis_name="local_devices")
 agent_state = agent_state.apply_gradients(grads=grads)
@@ -818,7 +836,7 @@ if __name__ == "__main__":
 f"data_time={rollout_queue_get_time[-1]:.2f}"
 )
 writer.add_scalar(
-"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
+"charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), global_step
 )
 writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
 writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
@@ -827,15 +845,13 @@ if __name__ == "__main__":
 writer.add_scalar("losses/loss", loss, global_step)
 if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
-ckpt_dir = f"checkpoints"
-os.makedirs(ckpt_dir, exist_ok=True)
 M_steps = args.batch_size * learner_policy_version // (2**20)
-model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
-with open(model_path, "wb") as f:
-f.write(
-flax.serialization.to_bytes(unreplicated_params)
-)
-print(f"model saved to {model_path}")
+ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
+ckpt_maneger.save(unreplicated_params, ckpt_name)
+if args.gcs_bucket is not None:
+zip_file_path = "latest.zip"
+zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
+sync_to_gcs(args.gcs_bucket, zip_file_path)
 if learner_policy_version >= args.num_updates:
 break
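Note: zip_files and sync_to_gcs also come from ygoai.rl.ckpt and are not shown in this diff; presumably they bundle the latest checkpoint together with the TensorBoard log directory and upload the archive to the configured bucket. A rough standalone sketch of what such helpers could look like, using the standard library plus google-cloud-storage (the behaviour and the dependency are assumptions, not the repo's actual implementation):

    import os
    import zipfile

    from google.cloud import storage  # assumed; the real helper may shell out to gsutil instead

    def zip_files(zip_path, paths):
        """Zip the given files and/or directories into a single archive."""
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for path in paths:
                if os.path.isdir(path):
                    for root, _, files in os.walk(path):
                        for name in files:
                            full = os.path.join(root, name)
                            zf.write(full, arcname=os.path.relpath(full))
                else:
                    zf.write(path, arcname=os.path.basename(path))

    def sync_to_gcs(bucket_name, local_path):
        """Upload a local file to gs://<bucket_name>/<basename of local_path>."""
        client = storage.Client()
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(os.path.basename(local_path))
        blob.upload_from_filename(local_path)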
...
@@ -587,12 +587,12 @@ if __name__ == "__main__":
 ),
 every_k_schedule=1,
 )
+tx = optax.apply_if_finite(tx, max_consecutive_errors=3)
 agent_state = TrainState.create(
 apply_fn=None,
 params=params,
 tx=tx,
 )
 if args.checkpoint:
 with open(args.checkpoint, "rb") as f:
 params = flax.serialization.from_bytes(params, f.read())
@@ -862,7 +862,7 @@ if __name__ == "__main__":
 f"data_time={rollout_queue_get_time[-1]:.2f}"
 )
 writer.add_scalar(
-"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
+"charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), global_step
 )
 writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
 writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
...