Commit 598465e8 authored by biluo.shen

Move reduce_gradidents out of the compiled train_step

parent 0559e98c
@@ -267,7 +267,6 @@ def run(local_rank, world_size):
         optimizer.zero_grad()
         scaler.scale(loss).backward()
         scaler.unscale_(optimizer)
-        reduce_gradidents(agent, args.world_size)
         return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss

     def predict_step(agent, next_obs):
@@ -403,6 +402,7 @@ def run(local_rank, world_size):
                 old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
                     train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
                                b_returns[mb_inds], b_values[mb_inds])
+                reduce_gradidents(agent, args.world_size)
                 nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                 scaler.step(optimizer)
                 scaler.update()
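For context, reduce_gradidents is the project's own helper and its body is not shown in this diff; given the world_size argument and the distributed training loop around it, it presumably all-reduces and averages parameter gradients across workers. A minimal sketch under that assumption (not the project's actual code), assuming torch.distributed is already initialized:

# Sketch only: average gradients across data-parallel workers.
import torch.distributed as dist

def reduce_gradidents(model, world_size):
    """All-reduce and average every parameter gradient across workers."""
    if world_size <= 1:
        return
    for param in model.parameters():
        if param.grad is not None:
            # Sum the gradient from all workers, then divide to get the mean.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)

Calling this after train_step returns, rather than inside it, keeps the cross-worker collective out of the (presumably torch.compile-wrapped) training step, which matches the intent of the commit title.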