Commit 8b35ba28 authored by sbl1996@126.com's avatar sbl1996@126.com

Set nan or inf loss to 0

parent ec9f3e0c
...@@ -730,6 +730,7 @@ if __name__ == "__main__": ...@@ -730,6 +730,7 @@ if __name__ == "__main__":
lambda x: jnp.sum(x * mask) / n_valids, (pg_loss, v_loss, ent_loss, approx_kl)) lambda x: jnp.sum(x * mask) / n_valids, (pg_loss, v_loss, ent_loss, approx_kl))
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl)) return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update( def single_device_update(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment