Commit 4c0edbf8 authored by sbl1996@126.com's avatar sbl1996@126.com

Add RND

parent d43e5903
This diff is collapsed.
This diff is collapsed.
......@@ -107,15 +107,15 @@ def get_from_action(values, action):
return jnp.sum(distrax.multiply_no_nan(values, value_one_hot), axis=-1)
def mean_legal(values, axis=None):
def mean_legal(values, axis=None, keepdims=False):
# TODO: use real action mask
no_nan_mask = values > -1e12
no_nan = jnp.where(no_nan_mask, values, 0)
count = jnp.sum(no_nan_mask, axis=axis)
return jnp.sum(no_nan, axis=axis) / jnp.maximum(count, 1)
count = jnp.sum(no_nan_mask, axis=axis, keepdims=keepdims)
return jnp.sum(no_nan, axis=axis, keepdims=keepdims) / jnp.maximum(count, 1)
def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
def neurd_loss_2(actions, logits, new_logits, advantages, logits_threshold):
# Neural Replicator Dynamics
# Differences from the original implementation:
# - all actions vs. sampled actions
......@@ -136,6 +136,27 @@ def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
return pg_loss
def neurd_loss(new_logits, advantages, logits_threshold=2.0, adv_threshold=1000.0):
advs = jax.lax.stop_gradient(advantages)
legal_mask = new_logits > -1e12
legal_logits = jnp.where(legal_mask, new_logits, 0)
count = jnp.sum(legal_mask, axis=-1, keepdims=True)
new_logits_ = new_logits - jnp.sum(legal_logits, axis=-1, keepdims=True) / jnp.maximum(count, 1)
can_increase = new_logits_ < logits_threshold
can_decrease = new_logits_ > -logits_threshold
c = jnp.where(
advs >= 0, can_increase, can_decrease).astype(jnp.float32)
c = jax.lax.stop_gradient(c)
advs = jnp.clip(advs, -adv_threshold, adv_threshold)
# TODO: renormalize with player
pg_loss = -c * new_logits_ * advs
pg_loss = jnp.where(legal_mask, pg_loss, 0)
pg_loss = jnp.sum(pg_loss, axis=-1)
return pg_loss
def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coef, dual_clip_coef=None):
# Actor-Critic Hedge loss from Actor-Critic Policy Optimization in a Large-Scale Imperfect-Information Game
# notice entropy term is required but not included here
......@@ -158,7 +179,83 @@ def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coe
return pg_loss
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
def vtrace_rnad_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, reward_u1, reward_u2 = carry
ratio, cur_values, r_t, eta_reg_entropy, probs, a_t, eta_log_policy, next_done, main = inp
v1 = jnp.where(next_done, 0, v1)
v2 = jnp.where(next_done, 0, v2)
next_values1 = jnp.where(next_done, 0, next_values1)
next_values2 = jnp.where(next_done, 0, next_values2)
reward1 = jnp.where(next_done, 0, reward1)
reward2 = jnp.where(next_done, 0, reward2)
xi1 = jnp.where(next_done, 1, xi1)
xi2 = jnp.where(next_done, 1, xi2)
reward_u1 = jnp.where(next_done, 0, reward_u1)
reward_u2 = jnp.where(next_done, 0, reward_u2)
discount = gamma * (1.0 - next_done)
next_v = jnp.where(main, v1, v2)
next_values = jnp.where(main, next_values1, next_values2)
reward = jnp.where(main, reward1, reward2)
xi = jnp.where(main, xi1, xi2)
reward_u = jnp.where(main, reward_u1, reward_u2)
reward_u = r_t + discount * reward_u + eta_reg_entropy
discounted_reward = r_t + discount * reward
rho_t = jnp.clip(ratio * xi, rho_min, rho_max)
c_t = jnp.clip(ratio * xi, c_min, c_max)
sig_v = rho_t * (reward_u + discount * next_values - cur_values)
v = cur_values + sig_v + c_t * discount * (next_v - next_values)
q_t = cur_values[:, None] + eta_log_policy
n_actions = eta_log_policy.shape[-1]
q_t2 = discounted_reward + discount * xi * next_v - cur_values
q_t = q_t + q_t2[:, None] * distrax.multiply_no_nan(
1.0 / jnp.maximum(probs, 1e-3), jax.nn.one_hot(a_t, n_actions))
v1 = jnp.where(main, v, discount * v1)
v2 = jnp.where(main, discount * v2, v)
next_values1 = jnp.where(main, cur_values, discount * next_values1)
next_values2 = jnp.where(main, discount * next_values2, cur_values)
reward1 = jnp.where(main, 0, ratio * (discount * reward1 - r_t) - eta_reg_entropy)
reward2 = jnp.where(main, ratio * (discount * reward2 - r_t) - eta_reg_entropy, 0)
xi1 = jnp.where(main, 1, ratio * xi1)
xi2 = jnp.where(main, ratio * xi2, 1)
reward_u1 = jnp.where(main, 0, discount * reward_u1 - r_t - eta_reg_entropy)
reward_u2 = jnp.where(main, discount * reward_u2 - r_t - eta_reg_entropy, 0)
carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, reward_u1, reward_u2
return carry, (v, q_t)
def vtrace_rnad(
next_value, ratios, logits, new_logits, actions,
log_policy_reg, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, eta=0.2,
):
probs = jax.nn.softmax(logits)
new_probs = jax.nn.softmax(new_logits)
eta_reg_entropy = -eta * jnp.sum(new_probs * log_policy_reg, axis=-1)
eta_log_policy = -eta * log_policy_reg
next_value1 = next_value
next_value2 = -next_value1
v1 = next_value1
v2 = next_value2
reward1 = reward2 = reward_u1 = reward_u2 = jnp.zeros_like(next_value)
xi1 = xi2 = jnp.ones_like(next_value)
carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, reward_u1, reward_u2
_, (targets, q_estimate) = jax.lax.scan(
partial(vtrace_rnad_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, rewards, eta_reg_entropy, probs, actions, eta_log_policy, next_dones, mains), reverse=True
)
targets = jax.lax.stop_gradient(targets)
return targets, q_estimate
def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry
ratio, cur_values, next_done, r_t, main = inp
......@@ -229,7 +326,7 @@ def vtrace_2p0s(
return1, return2, next_q1, next_q2
_, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
partial(vtrace_2p0s_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
advantages = q_estimate - values
......@@ -314,6 +411,29 @@ def truncated_gae_2p0s(
return targets, advantages
def truncated_gae_loop(carry, inp, gamma, gae_lambda):
lastgaelam, next_value = carry
cur_value, next_done, reward = inp
nextnonterminal = 1.0 - next_done
delta = reward + gamma * next_value * nextnonterminal - cur_value
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
carry = lastgaelam, cur_value
return carry, lastgaelam
def truncated_gae(next_value, values, rewards, next_dones, gamma, gae_lambda):
lastgaelam = jnp.zeros_like(next_value)
carry = lastgaelam, next_value
_, advantages = jax.lax.scan(
partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards), reverse=True
)
targets = values + advantages
targets = jax.lax.stop_gradient(targets)
return targets, advantages
def simple_policy_loss(ratios, logits, new_logits, advantages, kld_max, eps=1e-12):
advs = jax.lax.stop_gradient(advantages)
probs = jax.nn.softmax(logits)
......
......@@ -506,40 +506,39 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
@dataclass
class ModelArgs:
class EncoderArgs:
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
use_history: bool = True
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
noam: bool = False
"""whether to use Noam architecture for the transformer layer"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
action_feats: bool = True
"""whether to use action features for the global state"""
version: int = 0
"""the version of the environment and the agent"""
@dataclass
class ModelArgs(EncoderArgs):
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
class RNNAgent(nn.Module):
num_layers: int = 2
num_channels: int = 128
rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
switch: bool = True
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
rnn_type: str = 'lstm'
......@@ -549,6 +548,13 @@ class RNNAgent(nn.Module):
action_feats: bool = True
version: int = 0
switch: bool = True
freeze_id: bool = False
int_head: bool = False
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.num_channels
......@@ -618,6 +624,11 @@ class RNNAgent(nn.Module):
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r)
if self.int_head:
critic_int = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
value_int = critic_int(f_state_r)
value = (value, value_int)
return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
......@@ -636,4 +647,62 @@ class RNNAgent(nn.Module):
np.zeros((batch_size, num_heads*head_size*head_size)),
)
else:
return None
\ No newline at end of file
return None
default_rnd_args = EncoderArgs(
num_layers=1,
num_channels=128,
use_history=True,
card_mask=False,
noam=True,
action_feats=True,
version=2,
)
class RNDModel(nn.Module):
is_predictor: bool = False
num_layers: int = 1
num_channels: int = 128
use_history: bool = True
card_mask: bool = False
noam: bool = True
action_feats: bool = True
version: int = 2
out_channels: Optional[int] = None
freeze_id: bool = True
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.num_channels
oc = self.out_channels or c * 2
encoder = Encoder(
channels=c,
out_channels=oc,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
use_history=self.use_history,
card_mask=self.card_mask,
noam=self.noam,
action_feats=self.action_feats,
version=self.version,
)
f_actions, f_state, mask, valid = encoder(x)
c = f_state.shape[-1]
if self.is_predictor:
predictor = MLP([oc, oc], dtype=self.dtype, param_dtype=self.param_dtype)
f_state = predictor(f_state)
else:
f_state = nn.Dense(
oc, dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=nn.initializers.orthogonal(np.sqrt(2)))(f_state)
return f_state
\ No newline at end of file
import jax
import jax.numpy as jnp
from flax import struct
from ygoai.rl.env import RecordEpisodeStatistics
......@@ -24,3 +26,50 @@ def categorical_sample(logits, key):
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
return action, key
class RunningMeanStd(struct.PyTreeNode):
"""Tracks the mean, variance and count of values."""
mean: jnp.ndarray = struct.field(pytree_node=True)
var: jnp.ndarray = struct.field(pytree_node=True)
count: jnp.ndarray = struct.field(pytree_node=True)
@classmethod
def create(cls, shape=()):
return cls(
mean=jnp.zeros(shape, "float64"),
var=jnp.ones(shape, "float64"),
count=jnp.full(shape, 1e-4, "float64"),
)
def update(self, x):
"""Updates the mean, var and count from a batch of samples."""
batch_mean = jnp.mean(x, axis=0)
batch_var = jnp.var(x, axis=0)
batch_count = x.shape[0]
return self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
"""Updates from batch mean, variance and count moments."""
mean, var, count = update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count
)
return self.replace(mean=mean, var=var, count=count)
def update_mean_var_count_from_moments(
mean, var, count, batch_mean, batch_var, batch_count
):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
return new_mean, new_var, new_count
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment