Commit 892c7364 authored by sbl1996@126.com's avatar sbl1996@126.com

refactor PPO

parent 2bf8ce6a
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -101,21 +101,22 @@ def compute_gae_2p0s(
gamma, gae_lambda,
):
def body_fn(carry, inp):
pred_values, next_values, lastgaelam = carry
next_done, curvalues, reward, switch = inp
nextnonterminal = 1.0 - next_done
boot_value, boot_done, next_value, lastgaelam = carry
next_done, cur_value, reward, switch = inp
next_values = jnp.where(switch, -pred_values, next_values)
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
lastgaelam = jnp.where(switch, 0, lastgaelam)
delta = reward + gamma * next_values * nextnonterminal - curvalues
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
return (pred_values, curvalues, lastgaelam), lastgaelam
gamma_ = gamma * (1.0 - next_done)
delta = reward + gamma_ * next_value - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_value, lastgaelam
carry = next_value, next_done, next_value, lastgaelam
_, advantages = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
......@@ -130,28 +131,29 @@ def compute_gae_upgo_2p0s(
gamma, gae_lambda,
):
def body_fn(carry, inp):
pred_value, next_value, next_q, last_return, lastgaelam = carry
next_done, curvalues, reward, switch = inp
gamma_ = gamma * (1.0 - next_done)
boot_value, boot_done, next_value, next_q, last_return, lastgaelam = carry
next_done, cur_value, reward, switch = inp
next_value = jnp.where(switch, -pred_value, next_value)
next_q = jnp.where(switch, -pred_value, next_q)
last_return = jnp.where(switch, -pred_value, last_return)
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
next_q = jnp.where(switch, -boot_value * gamma, next_q)
last_return = jnp.where(switch, -boot_value, last_return)
lastgaelam = jnp.where(switch, 0, lastgaelam)
gamma_ = gamma * (1.0 - next_done)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
delta = next_q - curvalues
delta = next_q - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
carry = pred_value, next_value, next_q, last_return, lastgaelam
carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return)
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_value, next_value, next_value, lastgaelam
carry = next_value, next_done, next_value, next_value, next_value, lastgaelam
_, (advantages, returns) = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment