Commit 32aa61bf authored by biluo.shen

print twice

parent 366e5f3b
@@ -317,6 +317,7 @@ def run(local_rank, world_size):
     dones[step] = next_done
     _start = time.time()
+    torch._inductor.cudagraph_mark_step_begin()
     logits, value = predict_step(agent, next_obs)
     probs = Categorical(logits=logits)
     action = probs.sample()
@@ -334,6 +335,9 @@ def run(local_rank, world_size):
     rewards[step] = to_tensor(reward)
     next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_)
+    collect_time = time.time() - collect_start
+    print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
     if not writer:
         continue
@@ -358,8 +362,6 @@ def run(local_rank, world_size):
     avg_win_rates = []
     avg_ep_returns = []
-    collect_time = time.time() - collect_start
     # bootstrap value if not done
     with torch.no_grad():
         next_value = agent.get_value(next_obs).reshape(1, -1)
@@ -399,6 +401,7 @@ def run(local_rank, world_size):
     mb_obs = {
         k: v[mb_inds] for k, v in b_obs.items()
     }
+    torch._inductor.cudagraph_mark_step_begin()
     old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
         train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
                    b_returns[mb_inds], b_values[mb_inds])
@@ -413,8 +416,9 @@ def run(local_rank, world_size):
     train_time = time.time() - _start
-    if local_rank == 0:
-        print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
+    print(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}", flush=True)
+    # if local_rank == 0:
+    #     print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
     y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
     var_y = np.var(y_true)
...
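A note on the torch._inductor.cudagraph_mark_step_begin() calls added above: assuming predict_step and train_step are compiled with torch.compile(mode="reduce-overhead") (not shown in this diff), each compiled call replays a CUDA graph whose outputs live in static buffers, so a new iteration has to be marked before the next invocation, or the returned tensors cloned, to avoid reading memory that the next replay reuses. A minimal sketch of that pattern, with illustrative names (Net, compiled_step), not this repo's code:

import torch

# Toy module standing in for the agent; the real predict_step/train_step
# wrap a much larger network.
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

net = Net().cuda()
# reduce-overhead mode runs the compiled function via CUDA graph capture/replay.
compiled_step = torch.compile(net, mode="reduce-overhead")
x = torch.randn(16, 8, device="cuda")

for _ in range(3):
    # Mark a new iteration so the CUDA-graph runtime may reuse the previous
    # step's output buffers; anything needed later should be cloned.
    torch._inductor.cudagraph_mark_step_begin()
    out = compiled_step(x)
    kept = out.clone()  # keep a copy that survives the next replay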
@@ -70,7 +70,7 @@ class Args:
     """the number of steps to run in each environment per policy rollout"""
     anneal_lr: bool = True
     """Toggle learning rate annealing for policy and value networks"""
-    gamma: float = 0.99
+    gamma: float = 0.997
     """the discount factor gamma"""
     gae_lambda: float = 0.95
     """the lambda for the general advantage estimation"""
@@ -340,6 +340,9 @@ def run(local_rank, world_size):
     rewards[step] = to_tensor(reward)
     next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_)
+    collect_time = time.time() - collect_start
+    print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
     if not writer:
         continue
@@ -350,7 +353,7 @@ def run(local_rank, world_size):
     avg_ep_returns.append(episode_reward)
     if info['is_selfplay'][idx]:
         # win rate for the first player
-        pl = 1 if to_play[idx] == 0 else -1
+        pl = 1 if next_to_play[idx] == 0 else -1
         winner = 0 if episode_reward * pl > 0 else 1
         avg_sp_win_rates.append(1 - winner)
     else:
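On the to_play -> next_to_play change in the hunk above: the sketch below restates the win-rate bookkeeping it feeds, under the assumption that episode_reward is reported from the perspective of the player about to act when the episode ends, so the sign flip recovers whether the first player won. The helper name and the checks are illustrative, not code from this repo.

def first_player_won(episode_reward: float, next_to_play: int) -> bool:
    # +1 if the reward is already from player 0's perspective, -1 otherwise
    pl = 1 if next_to_play == 0 else -1
    # winner: 0 -> first player won, 1 -> second player won
    winner = 0 if episode_reward * pl > 0 else 1
    return winner == 0

# Illustrative checks: a +1 reward credited to player 1 means player 0 lost.
assert first_player_won(1.0, next_to_play=0)
assert not first_player_won(1.0, next_to_play=1)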
@@ -375,8 +378,6 @@ def run(local_rank, world_size):
     writer.add_scalar("charts/avg_sp_win_rate", np.mean(avg_sp_win_rates), global_step)
     avg_sp_win_rates = []
-    collect_time = time.time() - collect_start
     # bootstrap value if not done
     with torch.no_grad():
         next_value = agent.get_value(next_obs).reshape(1, -1)
@@ -437,8 +438,9 @@ def run(local_rank, world_size):
     train_time = time.time() - _start
-    if local_rank == 0:
-        print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
+    print(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}", flush=True)
+    # if local_rank == 0:
+    #     print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
     y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
     var_y = np.var(y_true)
...
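One hyperparameter also changes in the second file: gamma goes from 0.99 to 0.997. A quick check of the effective reward horizon 1 / (1 - gamma), which this change roughly triples:

# Effective reward horizon before and after the discount change.
for gamma in (0.99, 0.997):
    print(f"gamma={gamma}: horizon ~ {1 / (1 - gamma):.0f} steps")
# gamma=0.99: horizon ~ 100 steps
# gamma=0.997: horizon ~ 333 steps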