Commit e1ff8f92 authored by sbl1996@126.com

Add more rnn options and batch norm

parent 974fe861
@@ -8,7 +8,7 @@ import jax.numpy as jnp
import flax.linen as nn

from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
-from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
@@ -487,7 +487,7 @@ class Critic(nn.Module):
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
-    def __call__(self, f_state):
+    def __call__(self, f_state, train):
        f_state = f_state.astype(self.dtype)
        mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        x = mlp(self.channels, last_lin=False)(f_state)
@@ -495,6 +495,33 @@ class Critic(nn.Module):
        return x


class CrossCritic(nn.Module):
    channels: Sequence[int] = (128, 128, 128)
    # dropout_rate: Optional[float] = None
    batch_norm_momentum: float = 0.99
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, train):
        x = f_state.astype(self.dtype)
        linear = partial(nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
        BN = partial(
            BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
            momentum=self.batch_norm_momentum, axis_name="local_devices",
            use_running_average=not train)
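        # Editorial note: axis_name="local_devices" assumes the critic is applied
        # under jax.pmap (or an equivalent mapped transform) with an axis of that
        # name, so batch statistics are averaged across local devices when training.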
        x = BN()(x)
        for c in self.channels:
            x = linear(c)(x)
            # if self.use_layer_norm:
            #     x = nn.LayerNorm()(x)
            x = nn.relu(x)
            # x = nn.leaky_relu(x, negative_slope=0.1)
            x = BN()(x)
        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
        return x
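

# Editorial usage sketch, not part of this commit: BatchRenorm keeps its running
# statistics in the `batch_stats` collection, so CrossCritic has to be initialized
# and applied with that collection and an explicit `train` flag. The feature width,
# batch size, single-host device handling and the helper name `_cross_critic_example`
# are illustrative assumptions.
def _cross_critic_example():
    critic = CrossCritic(channels=(128, 128, 128), dtype=jnp.float32)
    f_state = jnp.zeros((4, 256))
    variables = critic.init(jax.random.PRNGKey(0), f_state, train=False)

    # Inference: the stored running statistics are used and nothing is mutated.
    value = critic.apply(variables, f_state, train=False)

    # Training: batch statistics are reduced over the "local_devices" axis, so the
    # call must run under jax.pmap with that axis name, and mutable=["batch_stats"]
    # collects the updated running statistics.
    n_dev = jax.local_device_count()
    r_variables = jax.device_put_replicated(variables, jax.local_devices())
    r_f_state = jnp.zeros((n_dev, 4, 256))
    value_t, updates = jax.pmap(
        lambda v, f: critic.apply(v, f, train=True, mutable=["batch_stats"]),
        axis_name="local_devices")(r_variables, r_f_state)
    return value, value_t, updates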


class GlobalCritic(nn.Module):
    channels: Sequence[int] = (128, 128)
    dtype: Optional[jnp.dtype] = None
@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs):
    """whether to use FiLM for the actor"""
    oppo_info: bool = False
    """whether to use opponent's information"""
    rnn_shortcut: bool = False
    """whether to use a shortcut connection around the RNN"""
    batch_norm: bool = False
    """whether to use batch normalization (BatchRenorm) for the critic"""
    critic_width: int = 128
    """the width of the critic"""
    critic_depth: int = 3
    """the depth of the critic"""
    rwkv_head_size: int = 32
    """the head size for the RWKV"""
@@ -596,6 +631,10 @@ class RNNAgent(nn.Module):
    rwkv_head_size: int = 32
    action_feats: bool = True
    oppo_info: bool = False
    rnn_shortcut: bool = False
    batch_norm: bool = False
    critic_width: int = 128
    critic_depth: int = 3
    version: int = 0
    switch: bool = True
@@ -606,7 +645,7 @@ class RNNAgent(nn.Module):
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
-    def __call__(self, x, rstate, done=None, switch_or_main=None):
+    def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
        batch_size = jax.tree.leaves(rstate)[0].shape[0]

        c = self.num_channels
@@ -669,6 +708,10 @@ class RNNAgent(nn.Module):
        rstate, f_state_r = rnn_step_by_main(
            rnn_layer, rstate, f_state, done, switch_or_main)
        if self.rnn_shortcut:
            # f_state_r = ReZero(channel_wise=True)(f_state_r)
            f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
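        # Editorial note: the concatenation makes the feature dimension of
        # f_state_r the sum of the encoder feature width and the RNN output width,
        # so the downstream heads see both pre- and post-RNN features.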
        if self.film:
            actor = FiLMActor(
                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
@@ -694,13 +737,16 @@ class RNNAgent(nn.Module):
                lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
            value = critic(rstate1_t, rstate2_t, f_g)
        else:
-            critic = Critic(
-                channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
-            value = critic(f_state_r)
+            CriticCls = CrossCritic if self.batch_norm else Critic
+            cs = [self.critic_width] * self.critic_depth
+            critic = CriticCls(
+                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+            value = critic(f_state_r, train)

        if self.int_head:
+            cs = [self.critic_width] * self.critic_depth
            critic_int = Critic(
-                channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
+                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
            value_int = critic_int(f_state_r)
            value = (value, value_int)

        return rstate, logits, value, valid
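
# Editorial sketch, not part of this commit: how the new options fit together
# (all other RNNAgent fields keep their defaults; the values are examples only):
#
#     agent = RNNAgent(
#         rnn_shortcut=True,   # concatenate pre- and post-RNN features
#         batch_norm=True,     # use CrossCritic (BatchRenorm) instead of Critic
#         critic_width=256,
#         critic_depth=2,      # critic channels become [256, 256]
#     )
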
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union, Optional, Any
import functools

import jax
import jax.numpy as jnp
import flax.linen as nn
+from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes


def decode_id(x):
@@ -109,4 +110,145 @@ class RMSNorm(nn.Module):
            "scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
        )
        x = x * scale
        return jnp.asarray(x, self.dtype)
\ No newline at end of file


class ReZero(nn.Module):
    channel_wise: bool = False
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        shape = (x.shape[-1],) if self.channel_wise else ()
        scale = self.param("scale", nn.initializers.zeros, shape, self.param_dtype)
        return x * scale
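

# Minimal usage sketch for ReZero (editorial example, not part of this commit):
# scale a residual branch by a zero-initialized learnable gate so that the block
# starts as the identity mapping. `ResidualBlock` and its layer sizes are
# hypothetical.
class ResidualBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(x.shape[-1])(x)
        h = nn.relu(h)
        # The gate is initialized to zero, so initially the block returns x unchanged.
        return x + ReZero(channel_wise=True)(h)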


class BatchRenorm(nn.Module):
    """BatchRenorm module, implemented based on the Batch Renormalization paper
    (https://arxiv.org/abs/1702.03275) and adapted from Flax's BatchNorm implementation:
    https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228

    Attributes:
      use_running_average: if True, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the result (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      use_bias: if True, bias (beta) is added.
      use_scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
        examples on the first two and last two devices. See `jax.lax.psum` for
        more details.
      use_fast_variance: If true, use a faster, but less numerically stable,
        calculation for the variance.
    """
    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.999
    epsilon: float = 0.001
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: nn.initializers.Initializer = nn.initializers.zeros
    scale_init: nn.initializers.Initializer = nn.initializers.ones
    axis_name: Optional[str] = None
    axis_index_groups: Any = None
    use_fast_variance: bool = True

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """
        Args:
          x: the input to be normalized.
          use_running_average: if true, the statistics stored in batch_stats will be
            used instead of computing the batch statistics on the input.

        Returns:
          Normalized inputs (the same shape as inputs).
        """
        use_running_average = nn.merge_param(
            'use_running_average', self.use_running_average, use_running_average
        )
        feature_axes = _canonicalize_axes(x.ndim, self.axis)
        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
        feature_shape = [x.shape[ax] for ax in feature_axes]

        ra_mean = self.variable(
            'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape)
        ra_var = self.variable(
            'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape)
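        # Editorial note: the three variables below hold scalar constants/counters
        # rather than statistics; `lambda s: s` simply stores the value passed in
        # place of a shape, so r_max is 3, d_max is 5 and steps starts at 0.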
        r_max = self.variable('batch_stats', 'r_max', lambda s: s, 3)
        d_max = self.variable('batch_stats', 'd_max', lambda s: s, 5)
        steps = self.variable('batch_stats', 'steps', lambda s: s, 0)
        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
            custom_mean = mean
            custom_var = var
        else:
            mean, var = _compute_stats(
                x,
                reduction_axes,
                dtype=self.dtype,
                axis_name=self.axis_name if not self.is_initializing() else None,
                axis_index_groups=self.axis_index_groups,
                use_fast_variance=self.use_fast_variance,
            )
            custom_mean = mean
            custom_var = var
            if not self.is_initializing():
                # The code below is implemented following the Batch Renormalization paper
                r = 1
                d = 0
                std = jnp.sqrt(var + self.epsilon)
                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
                r = jax.lax.stop_gradient(std / ra_std)
                r = jnp.clip(r, 1 / r_max.value, r_max.value)
                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
                d = jnp.clip(d, -d_max.value, d_max.value)
                tmp_var = var / (r**2)
                tmp_mean = mean - d * jnp.sqrt(custom_var) / r

                # Warm up batch renorm for 100_000 steps to build up proper running statistics
                warmed_up = jnp.greater_equal(steps.value, 100_000).astype(jnp.float32)
                custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
                custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean

                ra_mean.value = (
                    self.momentum * ra_mean.value + (1 - self.momentum) * mean
                )
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
                steps.value += 1
        return _normalize(
            self,
            x,
            custom_mean,
            custom_var,
            reduction_axes,
            feature_axes,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_bias,
            self.use_scale,
            self.bias_init,
            self.scale_init,
        )
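

# Editorial usage sketch, not part of this commit: BatchRenorm stores its running
# statistics (and the renorm constants) in the `batch_stats` collection, so
# training-mode calls need mutable=["batch_stats"]. The shapes, momentum and the
# helper name `_batch_renorm_example` are illustrative assumptions.
def _batch_renorm_example():
    bn = BatchRenorm(momentum=0.99)
    x = jnp.ones((8, 16))
    variables = bn.init(jax.random.PRNGKey(0), x, use_running_average=True)

    # Training step: compute batch statistics and update the running ones.
    y, updates = bn.apply(
        variables, x, use_running_average=False, mutable=["batch_stats"])
    variables = {**variables, **updates}

    # Evaluation: normalize with the stored running statistics.
    y_eval = bn.apply(variables, x, use_running_average=True)
    return y, y_eval
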
+from typing import Any, Callable

import jax
import jax.numpy as jnp
+from flax import core, struct
+from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
+import optax
import numpy as np

from ygoai.rl.env import RecordEpisodeStatistics
@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments(
    new_count = tot_count

    return new_mean, new_var, new_count


class TrainState(struct.PyTreeNode):
    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState = struct.field(pytree_node=True)
    batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)

    def apply_gradients(self, *, grads, **kwargs):
        """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.

        Note that internally this function calls ``.tx.update()`` followed by a call
        to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.

        Args:
          grads: Gradients that have the same pytree structure as ``.params``.
          **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.

        Returns:
          An updated instance of ``self`` with ``step`` incremented by one, ``params``
          and ``opt_state`` updated by applying ``grads``, and additional attributes
          replaced as specified by ``kwargs``.
        """
        if OVERWRITE_WITH_GRADIENT in grads:
            grads_with_opt = grads['params']
            params_with_opt = self.params['params']
        else:
            grads_with_opt = grads
            params_with_opt = self.params

        updates, new_opt_state = self.tx.update(
            grads_with_opt, self.opt_state, params_with_opt
        )
        new_params_with_opt = optax.apply_updates(params_with_opt, updates)

        # As implied by the OWG name, the gradients are used directly to update the
        # parameters.
        if OVERWRITE_WITH_GRADIENT in grads:
            new_params = {
                'params': new_params_with_opt,
                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
            }
        else:
            new_params = new_params_with_opt
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, **kwargs):
        """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
        # We exclude OWG params when present because they do not need opt states.
        params_with_opt = (
            params['params'] if OVERWRITE_WITH_GRADIENT in params else params
        )
        opt_state = tx.init(params_with_opt)
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )
\ No newline at end of file
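

# Editorial usage sketch, not part of this commit: compared to flax's TrainState,
# this variant also carries a `batch_stats` collection (e.g. for BatchRenorm-based
# critics). The linear model, shapes and the helper name `_train_state_example`
# are illustrative assumptions.
def _train_state_example():
    def apply_fn(variables, x):
        return x @ variables['params']['w'] + variables['params']['b']

    params = {'params': {'w': jnp.zeros((4, 1)), 'b': jnp.zeros((1,))}}
    state = TrainState.create(
        apply_fn=apply_fn,
        params=params,
        tx=optax.adam(1e-3),
        batch_stats={},  # would hold variables["batch_stats"] for BatchRenorm models
    )

    grads = jax.tree.map(jnp.zeros_like, state.params)
    state = state.apply_gradients(grads=grads)
    return state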