Commit 9d8d4386 authored by sbl1996@126.com

Fix bug: shuffle rstate in channels

parent 671ed3c6
This diff is collapsed.
...@@ -95,10 +95,9 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad ...@@ -95,10 +95,9 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad
return -jnp.mean(clipped_objective * mask) return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(6, 7)) @partial(jax.jit, static_argnums=(5, 6))
def compute_gae_2p0s( def compute_gae_2p0s(
next_value, next_done, values, rewards, dones, switch, next_value, values, rewards, next_dones, switch, gamma, gae_lambda,
gamma, gae_lambda,
): ):
def body_fn(carry, inp): def body_fn(carry, inp):
boot_value, boot_done, next_value, lastgaelam = carry boot_value, boot_done, next_value, lastgaelam = carry
...@@ -113,21 +112,20 @@ def compute_gae_2p0s( ...@@ -113,21 +112,20 @@ def compute_gae_2p0s(
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam
dones = jnp.concatenate([dones, next_done[None, :]], axis=0) next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value) lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, lastgaelam carry = next_value, next_done, next_value, lastgaelam
_, advantages = jax.lax.scan( _, advantages = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True body_fn, carry, (next_dones, values, rewards, switch), reverse=True
) )
target_values = advantages + values target_values = advantages + values
return advantages, target_values return advantages, target_values
@partial(jax.jit, static_argnums=(6, 7)) @partial(jax.jit, static_argnums=(5, 6))
def compute_gae_upgo_2p0s( def compute_gae_upgo_2p0s(
next_value, next_done, values, rewards, dones, switch, next_value, values, rewards, next_dones, switch,
gamma, gae_lambda, gamma, gae_lambda,
): ):
def body_fn(carry, inp): def body_fn(carry, inp):
...@@ -150,13 +148,12 @@ def compute_gae_upgo_2p0s( ...@@ -150,13 +148,12 @@ def compute_gae_upgo_2p0s(
carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return) return carry, (lastgaelam, last_return)
dones = jnp.concatenate([dones, next_done[None, :]], axis=0) next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value) lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, next_value, next_value, lastgaelam carry = next_value, next_done, next_value, next_value, next_value, lastgaelam
_, (advantages, returns) = jax.lax.scan( _, (advantages, returns) = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True body_fn, carry, (next_dones, values, rewards, switch), reverse=True
) )
return returns - values, advantages + values return returns - values, advantages + values
......
import jax
import jax.numpy as jnp import jax.numpy as jnp
from ygoai.rl.env import RecordEpisodeStatistics from ygoai.rl.env import RecordEpisodeStatistics
...@@ -14,3 +15,12 @@ def masked_normalize(x, valid, epsilon=1e-8): ...@@ -14,3 +15,12 @@ def masked_normalize(x, valid, epsilon=1e-8):
mean = x.sum() / n mean = x.sum() / n
variance = jnp.square(x - mean).sum() / n variance = jnp.square(x - mean).sum() / n
return (x - mean) / jnp.sqrt(variance + epsilon) return (x - mean) / jnp.sqrt(variance + epsilon)
def categorical_sample(logits, key):
    """Sample category indices from ``logits`` using the Gumbel-max trick.

    ``argmax(logits + g)`` with ``g = -log(-log(u))`` and ``u ~ Uniform(0, 1)``
    is distributed as ``Categorical(softmax(logits))``.
    See https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution

    Args:
        logits: array of unnormalized log-probabilities; categories are
            indexed by the last axis.
        key: JAX PRNG key. It is split internally; it must not be reused by
            the caller afterwards.

    Returns:
        A tuple ``(action, key)`` where ``action`` holds the sampled indices
        (``logits`` with the last axis reduced) and ``key`` is the fresh key
        to use for subsequent sampling.
    """
    key, subkey = jax.random.split(key)
    # jax.random.uniform draws from [0, 1). Keep u strictly positive so the
    # Gumbel noise stays finite (u == 0 would yield -inf noise, and NaN when
    # combined with a +inf logit). jax.random.gumbel uses the same lower bound.
    u = jax.random.uniform(
        subkey, shape=logits.shape, minval=jnp.finfo(jnp.float32).tiny
    )
    gumbel_noise = -jnp.log(-jnp.log(u))
    action = jnp.argmax(logits + gumbel_noise, axis=-1)
    return action, key
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment