Commit 3dfee5f5 authored by sbl1996@126.com

Fix int_reward compute

parent e6dc7744
@@ -162,7 +162,7 @@ class Args:
     """proportion of exp used for predictor update"""
     rnd_episodic: bool = False
     """whether to use episodic intrinsic reward for RND"""
-    rnd_norm: Literal["default", "min_max"] = "default"
+    rnd_norm: Literal["default", "min_max", "min_max2"] = "default"
     """the normalization method for RND intrinsic reward"""
     int_coef: float = 0.5
     """coefficient of intrinsic reward, 0.0 to disable RND"""
@@ -393,6 +393,15 @@ def rollout(
         rstate1, rstate2 = jax.tree.map(
             lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         return rstate1, rstate2, logits.argmax(axis=1)
+
+    @jax.jit
+    def compute_int_rew(params_rt, params_rp, obs):
+        target_feats = rnd_target.apply(params_rt, obs)
+        predict_feats = rnd_predictor.apply(params_rp, obs)
+        int_rewards = jnp.sum((target_feats - predict_feats) ** 2, axis=-1) / 2
+        if args.rnd_norm == 'min_max':
+            int_rewards = (int_rewards - int_rewards.min()) / (int_rewards.max() - int_rewards.min() + 1e-8)
+        return target_feats, int_rewards
 
     @jax.jit
     def sample_action(
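For intuition, the 'min_max' branch of compute_int_rew above rescales each batch of prediction errors into roughly the [0, 1] range. A tiny stand-alone sketch with made-up numbers (not part of the commit):

import jax.numpy as jnp

int_rewards = jnp.array([0.2, 0.5, 1.0])  # hypothetical per-env prediction errors
scaled = (int_rewards - int_rewards.min()) / (int_rewards.max() - int_rewards.min() + 1e-8)
# scaled is approximately [0.0, 0.375, 1.0]: the smallest error maps to 0, the largest to ~1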
@@ -403,12 +412,8 @@ def rollout(
         action, key = categorical_sample(logits, key)
 
         if args.enable_rnd:
-            target_feats = rnd_target.apply(params_rt, next_obs)
-            predict_feats = rnd_predictor.apply(params_rp, next_obs)
-            int_rewards = jnp.sum((target_feats - predict_feats) ** 2, axis=-1) / 2
-            if args.rnd_norm == 'min_max':
-                int_rewards = (int_rewards - int_rewards.min()) / (int_rewards.max() - int_rewards.min() + 1e-8)
-            else:
+            target_feats, int_rewards = compute_int_rew(params_rt, params_rp, next_obs)
+            if args.rnd_norm == 'default':
                 rewems = rewems * args.int_gamma + int_rewards
         else:
             target_feats = int_rewards = None
@@ -442,7 +447,7 @@ def rollout(
     np.random.shuffle(main_player)
 
     storage = []
-    reward_rms = jax.device_put(RunningMeanStd.create())
+    reward_rms = RunningMeanStd()
     rewems = jnp.zeros(args.local_num_envs, dtype=jnp.float32, device=actor_device)
 
     @jax.jit
@@ -550,16 +555,28 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)
 
         if args.enable_rnd:
+            next_int_reward = compute_int_rew(params_rt, params_rp, next_obs)[1]
+            all_int_rewards = all_int_rewards[1:] + [next_int_reward]
             # TODO: update every step
             all_int_rewards = jnp.stack(all_int_rewards)
             if args.rnd_norm == 'default':
-                reward_rms = reward_rms.update(jnp.array(all_dis_int_rewards).flatten())
-                all_int_rewards = all_int_rewards / jnp.sqrt(reward_rms.var)
+                all_dis_int_rewards = jnp.concatenate(all_dis_int_rewards)
+                mean, std = jax.device_get((
+                    all_dis_int_rewards.mean(), all_dis_int_rewards.std()))
+                count = len(all_dis_int_rewards)
+                reward_rms.update_from_moments(mean, std**2, count)
+                all_int_rewards = all_int_rewards / np.sqrt(reward_rms.var)
+            elif args.rnd_norm == 'min_max2':
+                max_int_rewards = jnp.max(all_int_rewards)
+                min_int_rewards = jnp.min(all_int_rewards)
+                all_int_rewards = (all_int_rewards - min_int_rewards) / (max_int_rewards - min_int_rewards)
-            mean_int_rewards = jnp.mean(all_int_rewards)
-            max_int_rewards = jnp.max(all_int_rewards)
 
             for k in range(args.num_steps):
                 int_rewards = all_int_rewards[k]
                 storage[k] = storage[k]._replace(int_rewards=int_rewards)
+            mean_int_rewards = jnp.mean(all_int_rewards)
+            max_int_rewards = jnp.max(all_int_rewards)
 
         partitioned_storage = prepare_data(storage)
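The 'default' branch above keeps the RND-style normalization: each step's intrinsic reward is folded into a per-env discounted return (rewems), the running variance of those returns is updated once per rollout, and the raw rewards are divided by the resulting standard deviation. A rough self-contained sketch with toy shapes and numbers (the real code accumulates the variance across rollouts via RunningMeanStd instead of recomputing it per batch):

import numpy as np

int_gamma = 0.99
all_int_rewards = np.array([[0.2, 0.1, 0.4, 0.3],   # toy data: 3 steps x 4 envs of
                            [0.5, 0.2, 0.1, 0.6],   # raw RND prediction errors
                            [0.3, 0.3, 0.2, 0.1]])

rewems = np.zeros(4)                  # discounted intrinsic return, carried across steps
all_dis_int_rewards = []
for int_rewards in all_int_rewards:
    rewems = rewems * int_gamma + int_rewards
    all_dis_int_rewards.append(rewems)
all_dis_int_rewards = np.concatenate(all_dis_int_rewards)

# stand-in for reward_rms.var after update_from_moments
var = all_dis_int_rewards.var()
normalized = all_int_rewards / np.sqrt(var)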
......
 import jax
 import jax.numpy as jnp
-from flax import struct
+import numpy as np
 
 from ygoai.rl.env import RecordEpisodeStatistics
@@ -28,35 +28,28 @@ def categorical_sample(logits, key):
     return action, key
 
 
-class RunningMeanStd(struct.PyTreeNode):
+class RunningMeanStd:
     """Tracks the mean, variance and count of values."""
-    mean: jnp.ndarray = struct.field(pytree_node=True)
-    var: jnp.ndarray = struct.field(pytree_node=True)
-    count: jnp.ndarray = struct.field(pytree_node=True)
-
-    @classmethod
-    def create(cls, shape=()):
-        # TODO: use numpy and float64
-        return cls(
-            mean=jnp.zeros(shape, "float32"),
-            var=jnp.ones(shape, "float32"),
-            count=jnp.full(shape, 1e-4, "float32"),
-        )
+    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+    def __init__(self, epsilon=1e-4, shape=()):
+        """Tracks the mean, variance and count of values."""
+        self.mean = np.zeros(shape, "float64")
+        self.var = np.ones(shape, "float64")
+        self.count = epsilon
 
     def update(self, x):
         """Updates the mean, var and count from a batch of samples."""
-        batch_mean = jnp.mean(x, axis=0)
-        batch_var = jnp.var(x, axis=0)
+        batch_mean = np.mean(x, axis=0)
+        batch_var = np.var(x, axis=0)
         batch_count = x.shape[0]
-        return self.update_from_moments(batch_mean, batch_var, batch_count)
+        self.update_from_moments(batch_mean, batch_var, batch_count)
 
     def update_from_moments(self, batch_mean, batch_var, batch_count):
         """Updates from batch mean, variance and count moments."""
-        mean, var, count = update_mean_var_count_from_moments(
+        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
         )
-        return self.replace(mean=mean, var=var, count=count)
 
 
 def update_mean_var_count_from_moments(
@@ -69,7 +62,7 @@ def update_mean_var_count_from_moments(
     new_mean = mean + delta * batch_count / tot_count
 
     m_a = var * count
     m_b = batch_var * batch_count
-    M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
+    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
     new_var = M2 / tot_count
     new_count = tot_count
......
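As a quick sanity check of the numpy-backed RunningMeanStd (a usage sketch, not part of the commit): feeding two batches through update() should roughly reproduce the statistics of the concatenated data, which is what the parallel-variance formula referenced above guarantees, up to the small epsilon prior count.

import numpy as np

rms = RunningMeanStd()                 # assumes the class above is in scope
a = np.random.randn(512) * 2.0 + 1.0
b = np.random.randn(256) * 0.5 - 3.0

rms.update(a)
rms.update(b)

both = np.concatenate([a, b])
print(rms.mean, both.mean())           # close, up to the 1e-4 prior count
print(np.sqrt(rms.var), both.std())

# In the rollout loop the intrinsic rewards are then scaled as:
#   all_int_rewards = all_int_rewards / np.sqrt(reward_rms.var)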