Commit 537b59fb authored by sbl1996@126.com's avatar sbl1996@126.com

Add burn in follow R2D2

parent 63e558fb
......@@ -105,6 +105,8 @@ class Args:
"""Toggle the use of switch mechanism"""
norm_adv: bool = False
"""Toggles advantages normalization"""
burn_in_steps: Optional[int] = None
"""the number of burn-in steps for training (for R2D2)"""
upgo: bool = True
"""Toggle the use of UPGO for advantages"""
......@@ -661,7 +663,6 @@ if __name__ == "__main__":
return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
if args.switch:
dones = dones | next_dones
......@@ -672,7 +673,7 @@ if __name__ == "__main__":
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
approx_kl = (ratios - 1) - logratio
new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
......@@ -709,20 +710,23 @@ if __name__ == "__main__":
else:
pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)
pg_loss = jnp.sum(pg_loss * mask)
v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask)
if args.vloss_clip is not None:
v_loss = jnp.minimum(v_loss, args.vloss_clip)
ent_loss = entropy_loss(new_logits)
ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
ent_loss = ent_loss / n_valids
if args.burn_in_steps:
mask = jax.tree.map(
lambda x: x.reshape(num_steps, num_envs), mask)
burn_in_mask = jnp.arange(num_steps) < args.burn_in_steps
mask = jnp.where(burn_in_mask[:, None], 0.0, mask)
mask = jnp.reshape(mask, (-1,))
if args.vloss_clip is not None:
v_loss = jnp.minimum(v_loss, args.vloss_clip)
n_valids = jnp.sum(mask)
pg_loss, v_loss, ent_loss, approx_kl = jax.tree.map(
lambda x: jnp.sum(x * mask) / n_valids, (pg_loss, v_loss, ent_loss, approx_kl))
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment