Commit 9d8d4386 authored by sbl1996@126.com

Fix bug: shuffle rstate in channels

parent 671ed3c6
This diff is collapsed.
......@@ -95,10 +95,9 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad
return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(6, 7))
@partial(jax.jit, static_argnums=(5, 6))
def compute_gae_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
next_value, values, rewards, next_dones, switch, gamma, gae_lambda,
):
def body_fn(carry, inp):
boot_value, boot_done, next_value, lastgaelam = carry
......@@ -113,21 +112,20 @@ def compute_gae_2p0s(
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, lastgaelam
_, advantages = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
target_values = advantages + values
return advantages, target_values
@partial(jax.jit, static_argnums=(6, 7))
@partial(jax.jit, static_argnums=(5, 6))
def compute_gae_upgo_2p0s(
next_value, next_done, values, rewards, dones, switch,
next_value, values, rewards, next_dones, switch,
gamma, gae_lambda,
):
def body_fn(carry, inp):
......@@ -150,13 +148,12 @@ def compute_gae_upgo_2p0s(
carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return)
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, next_value, next_value, lastgaelam
_, (advantages, returns) = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
return returns - values, advantages + values
......
import jax
import jax.numpy as jnp
from ygoai.rl.env import RecordEpisodeStatistics
......@@ -13,4 +14,13 @@ def masked_normalize(x, valid, epsilon=1e-8):
n = valid.sum()
mean = x.sum() / n
variance = jnp.square(x - mean).sum() / n
return (x - mean) / jnp.sqrt(variance + epsilon)
\ No newline at end of file
return (x - mean) / jnp.sqrt(variance + epsilon)
def categorical_sample(logits, key):
    """Sample a category index from ``logits`` using the Gumbel-max trick.

    Args:
        logits: Array of per-category scores; the category axis is the last
            axis (presumably unnormalized log-probabilities -- confirm
            against the caller).
        key: JAX PRNG key. It is split once; one half drives the sampling.

    Returns:
        A tuple ``(action, key)`` where ``action`` holds the argmax index
        along the last axis of the noise-perturbed logits, and ``key`` is the
        other half of the split, returned for the caller to use next.
    """
    # Gumbel-max trick (a hard argmax, not the Gumbel-softmax relaxation):
    # adding independent Gumbel(0, 1) noise to each logit and taking the
    # argmax yields a sample from the corresponding categorical distribution.
    # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=logits.shape)
    # -log(-log(u)) turns the uniform draw into Gumbel(0, 1) noise.
    action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
    return action, key
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment