Commit b8929b9c authored by sbl1996@126.com

refactor agent2

parent 8b35ba28
@@ -52,6 +52,10 @@ class Args:
     tb_dir: str = "runs"
     """the directory to save the tensorboard logs"""
+    tb_offset: int = 0
+    """the step offset of the tensorboard logs"""
+    run_name: Optional[str] = None
+    """the name of the tensorboard run"""
     ckpt_dir: str = "checkpoints"
     """the directory to save the model checkpoints"""
     gcs_bucket: Optional[str] = None
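The two new fields appear intended to make a run resumable in TensorBoard: `run_name` pins logging to an existing run, and `tb_offset` shifts every logged step by the length of the previous run so the curves continue rather than restart at zero. A minimal sketch of the arithmetic (values hypothetical), matching the `tb_global_step` lines later in this diff:

    # Previous run ended at global_step 48_000_000; resume with tb_offset=48000000.
    # A new global_step of 1_000 is then logged at 48_001_000:
    tb_global_step = args.tb_offset + global_step  # 48_000_000 + 1_000 = 48_001_000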
@@ -223,7 +227,7 @@ class Transition(NamedTuple):
     next_dones: list

-def create_agent(args, multi_step=False, eval=False):
+def create_agent(args, eval=False):
     return RNNAgent(
         channels=args.num_channels,
         num_layers=args.num_layers,
@@ -232,7 +236,6 @@ def create_agent(args, multi_step=False, eval=False):
         param_dtype=jnp.float32,
         rnn_channels=args.rnn_channels,
         switch=args.switch,
-        multi_step=multi_step,
         freeze_id=args.freeze_id,
         use_history=args.use_history if not eval else args.eval_use_history,
         rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
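With the keyword gone, `create_agent` builds one `RNNAgent` that serves both the actor's single-step inference and the learner's multi-step unroll; the mode is now inferred from the arity of the inputs (see the `RNNAgent.__call__` change below). A sketch of the learner-side call, with names taken from this diff:

    # Five inputs -> multi_step is inferred as True inside the module.
    inputs = (rstate1, rstate2, obs, dones, switch_or_mains)
    _rstate, new_logits, new_values, _valid = create_agent(args).apply(params, inputs)
    # The actor side presumably passes two inputs, e.g. (rstate, obs), giving multi_step False.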
@@ -480,23 +483,26 @@ def rollout(
         avg_episodic_length = np.mean(envs.returned_episode_lengths)
         SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
         SPS_update = int(args.batch_size / (time.time() - update_time_start))
+        tb_global_step = args.tb_offset + global_step
         if device_thread_id == 0:
             print(
-                f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
+                f"global_step={tb_global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
             )
             time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
             print(
                 f"{time_now} SPS: {SPS}, update: {SPS_update}, "
                 f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
             )
-        writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
-        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
-        writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
-        writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
-        writer.add_scalar("stats/inference_time", inference_time, global_step)
-        writer.add_scalar("stats/env_time", env_time, global_step)
-        writer.add_scalar("charts/SPS", SPS, global_step)
-        writer.add_scalar("charts/SPS_update", SPS_update, global_step)
+        writer.add_scalar("stats/rollout_time", np.mean(rollout_time), tb_global_step)
+        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, tb_global_step)
+        writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, tb_global_step)
+        writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), tb_global_step)
+        writer.add_scalar("stats/inference_time", inference_time, tb_global_step)
+        writer.add_scalar("stats/env_time", env_time, tb_global_step)
+        writer.add_scalar("charts/SPS", SPS, tb_global_step)
+        writer.add_scalar("charts/SPS_update", SPS_update, tb_global_step)

 if __name__ == "__main__":
@@ -554,8 +560,12 @@ if __name__ == "__main__":
     args.learner_devices = [str(item) for item in learner_devices]
     pprint(args)

-    timestamp = int(time.time())
-    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
+    if args.run_name is None:
+        timestamp = int(time.time())
+        run_name = f"{args.exp_name}__{args.seed}__{timestamp}"
+    else:
+        run_name = args.run_name
+        timestamp = int(run_name.split("__")[-1])

     dummy_writer = SimpleNamespace()
     dummy_writer.add_scalar = lambda x, y, z: None
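The `else` branch recovers `timestamp` from a user-supplied `run_name`, which only works if the name keeps the `{exp_name}__{seed}__{timestamp}` shape produced by the `if` branch. A round-trip sketch (values hypothetical):

    run_name = "ppo__1__1718000000"            # f"{args.exp_name}__{args.seed}__{timestamp}"
    timestamp = int(run_name.split("__")[-1])  # -> 1718000000
    # A run_name whose last "__"-separated field is not an integer raises ValueError here.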
@@ -668,7 +678,7 @@ if __name__ == "__main__":
         inputs = (rstate1, rstate2, obs, dones, switch_or_mains)
         _rstate, new_logits, new_values, _valid = create_agent(
-            args, multi_step=True).apply(params, inputs)
+            args).apply(params, inputs)
         new_values = new_values.squeeze(-1)

         ratios = distrax.importance_sampling_ratios(distrax.Categorical(
@@ -897,13 +907,14 @@ if __name__ == "__main__":
             if eval_stats is not None:
                 eval_stat_list.append(eval_stats)

+            tb_global_step = args.tb_offset + global_step
             if update % args.eval_interval == 0:
                 eval_stats = np.mean(eval_stat_list, axis=0)
                 eval_stats = jax.device_put(eval_stats, local_devices[0])
                 eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
                 eval_time, eval_return, eval_win_rate = eval_stats
-                writer.add_scalar(f"charts/eval_return", eval_return, global_step)
-                writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
+                writer.add_scalar(f"charts/eval_return", eval_return, tb_global_step)
+                writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, tb_global_step)
                 print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")

         rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
@@ -927,31 +938,31 @@ if __name__ == "__main__":
         # record rewards for plotting purposes
         if learner_policy_version % args.log_frequency == 0:
-            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
+            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), tb_global_step)
             writer.add_scalar(
                 "stats/rollout_params_queue_get_time_diff",
                 np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
-                global_step,
+                tb_global_step,
             )
-            writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
-            writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
-            writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
+            writer.add_scalar("stats/training_time", time.time() - training_time_start, tb_global_step)
+            writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), tb_global_step)
+            writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), tb_global_step)
             print(
-                f"{global_step} actor_update={update}, "
+                f"{tb_global_step} actor_update={update}, "
                 f"train_time={time.time() - training_time_start:.2f}, "
                 f"data_time={rollout_queue_get_time[-1]:.2f}"
             )
             writer.add_scalar(
-                "charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), global_step
+                "charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), tb_global_step
             )
-            writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
-            writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
-            writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
-            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
-            writer.add_scalar("losses/loss", loss, global_step)
+            writer.add_scalar("losses/value_loss", v_loss[-1].item(), tb_global_step)
+            writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), tb_global_step)
+            writer.add_scalar("losses/entropy", entropy_loss[-1].item(), tb_global_step)
+            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), tb_global_step)
+            writer.add_scalar("losses/loss", loss, tb_global_step)

         if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
-            M_steps = args.batch_size * learner_policy_version // 2**20
+            M_steps = tb_global_step // 2**20
             ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
             ckpt_maneger.save(unreplicated_params, ckpt_name)
             if args.gcs_bucket is not None:
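The checkpoint-name change keeps numbering monotonic across resumed runs: `M_steps` is now derived from the offset step counter instead of `args.batch_size * learner_policy_version`, which would restart from zero in a resumed process. A worked example (numbers hypothetical):

    tb_offset = 50 * 2**20                        # resumed after roughly 50M steps
    global_step = 4 * 2**20                       # about 4M steps into the new run
    M_steps = (tb_offset + global_step) // 2**20  # -> 54, i.e. "<timestamp>_54M.flax_model"

Because `timestamp` is likewise recovered from `run_name` on resume, the continued run saves under the same filename prefix as the run it extends.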
...
@@ -339,7 +339,6 @@ class RNNAgent(nn.Module):
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: jnp.dtype = jnp.float32
     param_dtype: jnp.dtype = jnp.float32
-    multi_step: bool = False
     switch: bool = True
     freeze_id: bool = False
     use_history: bool = True
@@ -347,7 +346,8 @@ class RNNAgent(nn.Module):
     @nn.compact
     def __call__(self, inputs):
-        if self.multi_step:
+        multi_step = len(inputs) != 2
+        if multi_step:
             # (num_steps * batch_size, ...)
             *rstate, x, done, switch_or_main = inputs
         else:
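Deriving `multi_step` from `len(inputs)` replaces the former module field, so the same module instance and bound parameters can be applied in either mode without rebuilding the agent. Since the flag only selected a control-flow path, it presumably never affected parameter shapes, and checkpoints saved before this refactor should remain loadable. The dispatch, with element names from this diff:

    # Multi-step: five elements, observations flattened over time.
    *rstate, x, done, switch_or_main = inputs    # rstate == [rstate1, rstate2]
    # Single-step: presumably just (rstate, x), so len(inputs) == 2 and multi_step is False.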
@@ -380,7 +380,7 @@ class RNNAgent(nn.Module):
         elif self.rnn_type == 'none':
             f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
         else:
-            if self.multi_step:
+            if multi_step:
                 rstate1, rstate2 = rstate
                 batch_size = jax.tree.leaves(rstate1)[0].shape[0]
                 num_steps = done.shape[0] // batch_size
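In the multi-step path the time axis is implicit: inputs arrive flattened to `(num_steps * batch_size, ...)`, the batch size is read off the carried RNN state, and the number of steps is recovered by division before the recurrent scan. A shape sketch under those assumptions (sizes hypothetical):

    batch_size = jax.tree.leaves(rstate1)[0].shape[0]  # e.g. 64, from the carried state
    num_steps = done.shape[0] // batch_size            # e.g. 2048 // 64 = 32
    # Presumably the flat (T*B, ...) arrays are then reshaped to (T, B, ...) for the scan.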
...