Commit 598465e8 authored by biluo.shen

Move reduce_gradidents out of the compiled train_step

parent 0559e98c
@@ -267,7 +267,6 @@ def run(local_rank, world_size):
         optimizer.zero_grad()
         scaler.scale(loss).backward()
         scaler.unscale_(optimizer)
-        reduce_gradidents(agent, args.world_size)
         return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
 
     def predict_step(agent, next_obs):
@@ -403,6 +402,7 @@ def run(local_rank, world_size):
             old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
                 train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
                            b_returns[mb_inds], b_values[mb_inds])
+            reduce_gradidents(agent, args.world_size)
             nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
             scaler.step(optimizer)
             scaler.update()
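The helper itself is not shown in this diff. For context, here is a minimal sketch of what a reduce_gradidents-style helper typically does, assuming gradients are averaged across ranks with torch.distributed.all_reduce; the name, signature, and details below are illustrative, not the repository's actual code:

    import torch
    import torch.distributed as dist

    def reduce_gradients(model: torch.nn.Module, world_size: int) -> None:
        # Average gradients across all ranks; a no-op in single-GPU runs.
        # Assumes the default process group has already been initialized.
        if world_size <= 1:
            return
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= world_size

Moving this call out of train_step keeps the collective communication op outside the compiled region: torch.compile (at least in earlier PyTorch releases) could not trace dist.all_reduce cleanly, so calling it inside train_step risked graph breaks or recompilation. With this change the backward pass stays fully compiled, while the all-reduce, gradient clipping, and optimizer step run eagerly.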