Commit e1ff8f92 authored by sbl1996@126.com

Add more rnn options and batch norm

parent 974fe861
@@ -8,7 +8,7 @@ import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
@@ -487,7 +487,7 @@ class Critic(nn.Module):
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
def __call__(self, f_state, train):
f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state)
@@ -495,6 +495,33 @@ class Critic(nn.Module):
return x
class CrossCritic(nn.Module):
channels: Sequence[int] = (128, 128, 128)
# dropout_rate: Optional[float] = None
batch_norm_momentum: float = 0.99
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state, train):
x = f_state.astype(self.dtype)
linear = partial(nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
BN = partial(
BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
momentum=self.batch_norm_momentum, axis_name="local_devices",
use_running_average=not train)
x = BN()(x)
for c in self.channels:
x = linear(c)(x)
# if self.use_layer_norm:
# x = nn.LayerNorm()(x)
x = nn.relu(x)
# x = nn.leaky_relu(x, negative_slope=0.1)
x = BN()(x)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
return x
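Usage sketch (illustrative, not part of the commit): initializing and evaluating the new BatchRenorm-based critic, assuming CrossCritic is imported from this agent module; the feature shape is made up. Because its BatchRenorm layers are built with axis_name="local_devices", a train=True call has to run inside a jax.pmap (or shard_map) that defines this axis and must mark batch_stats as mutable; a fuller training sketch follows the new TrainState at the end of this diff.

import jax
import jax.numpy as jnp

f_state = jnp.zeros((8, 256))                      # illustrative (batch, features)
critic = CrossCritic(channels=(128, 128, 128), dtype=jnp.float32)
variables = critic.init(jax.random.PRNGKey(0), f_state, False)

# train=False: running statistics are used and nothing is mutated.
values = critic.apply(variables, f_state, False)   # shape (8, 1)

# train=True (inside a pmap that names "local_devices") recomputes batch statistics:
#   values, updates = critic.apply(variables, f_state, True, mutable=["batch_stats"])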
class GlobalCritic(nn.Module):
channels: Sequence[int] = (128, 128)
dtype: Optional[jnp.dtype] = None
@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs):
"""whether to use FiLM for the actor"""
oppo_info: bool = False
"""whether to use opponent's information"""
rnn_shortcut: bool = False
"""whether to add a shortcut (skip) connection around the RNN"""
batch_norm: bool = False
"""whether to use the BatchRenorm-based critic (CrossCritic)"""
critic_width: int = 128
"""the hidden width of the critic MLP"""
critic_depth: int = 3
"""the number of hidden layers of the critic MLP"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
@@ -596,6 +631,10 @@ class RNNAgent(nn.Module):
rwkv_head_size: int = 32
action_feats: bool = True
oppo_info: bool = False
rnn_shortcut: bool = False
batch_norm: bool = False
critic_width: int = 128
critic_depth: int = 3
version: int = 0
switch: bool = True
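Construction sketch (illustrative): the new RNNAgent fields mirror the ModelArgs options above; the values below are examples, not the repository's settings.

agent = RNNAgent(
    rnn_shortcut=True,    # concatenate pre-RNN features with the RNN output
    batch_norm=True,      # use the BatchRenorm-based CrossCritic
    critic_width=256,     # hidden width of the critic MLP
    critic_depth=3,       # number of hidden layers in the critic MLP
)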
@@ -606,7 +645,7 @@
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
batch_size = jax.tree.leaves(rstate)[0].shape[0]
c = self.num_channels
@@ -669,6 +708,10 @@
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
if self.rnn_shortcut:
# f_state_r = ReZero(channel_wise=True)(f_state_r)
f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
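# rnn_shortcut: concatenating the pre-RNN features with the RNN output widens the
# feature vector fed to the actor and critic (a residual-style skip around the RNN)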
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
@@ -694,13 +737,16 @@
lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
value = critic(rstate1_t, rstate2_t, f_g)
else:
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
value = critic(f_state_r)
CriticCls = CrossCritic if self.batch_norm else Critic
cs = [self.critic_width] * self.critic_depth
critic = CriticCls(
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value = critic(f_state_r, train)
if self.int_head:
cs = [self.critic_width] * self.critic_depth
critic_int = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value_int = critic_int(f_state_r)
value = (value, value_int)
return rstate, logits, value, valid
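Sketch of the new critic selection (illustrative): with the shared (f_state, train) signature, Critic and CrossCritic are drop-in replacements for each other; the feature tensor and sizes below are made up.

import jax
import jax.numpy as jnp

f_state_r = jnp.zeros((4, 512))                        # illustrative RNN output features
for batch_norm in (False, True):
    CriticCls = CrossCritic if batch_norm else Critic
    critic = CriticCls(channels=[256] * 3, dtype=jnp.float32)
    params = critic.init(jax.random.PRNGKey(0), f_state_r, False)
    value = critic.apply(params, f_state_r, False)     # train=False: no batch_stats mutation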
from typing import Tuple, Union, Optional
from typing import Tuple, Union, Optional, Any
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes
def decode_id(x):
@@ -109,4 +110,145 @@ class RMSNorm(nn.Module):
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
)
x = x * scale
return jnp.asarray(x, self.dtype)
\ No newline at end of file
return jnp.asarray(x, self.dtype)
class ReZero(nn.Module):
channel_wise: bool = False
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
shape = (x.shape[-1],) if self.channel_wise else ()
scale = self.param("scale", nn.initializers.zeros, shape, self.param_dtype)
return x * scale
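Usage sketch (illustrative): ReZero scales its input by a zero-initialized learnable gate, so a gated branch contributes nothing at initialization and is opened up during training; the commented-out line in RNNAgent above would apply it to the RNN output before the shortcut concatenation.

import jax
import jax.numpy as jnp

x = jnp.ones((2, 8))
rz = ReZero(channel_wise=True)
params = rz.init(jax.random.PRNGKey(0), x)
y = rz.apply(params, x)    # all zeros at init; the per-channel scale is learned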
class BatchRenorm(nn.Module):
"""BatchRenorm Module, implemented based on the Batch Renormalization paper (https://arxiv.org/abs/1702.03275).
and adapted from Flax's BatchNorm implementation:
https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228
Attributes:
use_running_average: if True, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of the batch
statistics.
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
examples on the first two and last two devices. See `jax.lax.psum` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""
use_running_average: Optional[bool] = None
axis: int = -1
momentum: float = 0.999
epsilon: float = 0.001
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: nn.initializers.Initializer = nn.initializers.zeros
scale_init: nn.initializers.Initializer = nn.initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True
@nn.compact
def __call__(self, x, use_running_average: Optional[bool] = None):
"""
Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.
Returns:
Normalized inputs (the same shape as inputs).
"""
use_running_average = nn.merge_param(
'use_running_average', self.use_running_average, use_running_average
)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
feature_shape = [x.shape[ax] for ax in feature_axes]
ra_mean = self.variable(
'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape)
ra_var = self.variable(
'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape)
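# Non-learned scalars stored with the running statistics: the clipping bounds for the
# renorm corrections r and d (3 and 5, the paper's final values) and a step counter
# that drives the warm-up below.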
r_max = self.variable('batch_stats', 'r_max', lambda s: s, 3)
d_max = self.variable('batch_stats', 'd_max', lambda s: s, 5)
steps = self.variable('batch_stats', 'steps', lambda s: s, 0)
if use_running_average:
mean, var = ra_mean.value, ra_var.value
custom_mean = mean
custom_var = var
else:
mean, var = _compute_stats(
x,
reduction_axes,
dtype=self.dtype,
axis_name=self.axis_name if not self.is_initializing() else None,
axis_index_groups=self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
)
custom_mean = mean
custom_var = var
if not self.is_initializing():
# The code below is implemented following the Batch Renormalization paper
r = 1
d = 0
std = jnp.sqrt(var + self.epsilon)
ra_std = jnp.sqrt(ra_var.value + self.epsilon)
r = jax.lax.stop_gradient(std / ra_std)
r = jnp.clip(r, 1 / r_max.value, r_max.value)
d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
d = jnp.clip(d, -d_max.value, d_max.value)
tmp_var = var / (r**2)
tmp_mean = mean - d * jnp.sqrt(custom_var) / r
# During the first 100_000 steps, fall back to plain batch norm so proper running
# statistics can build up before the renorm correction is applied
warmed_up = jnp.greater_equal(steps.value, 100_000).astype(jnp.float32)
custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean
ra_mean.value = (
self.momentum * ra_mean.value + (1 - self.momentum) * mean
)
ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
steps.value += 1
return _normalize(
self,
x,
custom_mean,
custom_var,
reduction_axes,
feature_axes,
self.dtype,
self.param_dtype,
self.epsilon,
self.use_bias,
self.use_scale,
self.bias_init,
self.scale_init,
)
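Standalone usage sketch (illustrative): with the default axis_name=None the layer computes per-device statistics, so it can be exercised outside of pmap. In training mode it applies the renorm correction r = sigma_batch / sigma_ema (clipped to [1/r_max, r_max]) and d = (mu_batch - mu_ema) / sigma_ema (clipped to [-d_max, d_max]), but only after the 100_000-step warm-up; before that it behaves like plain batch normalization while the running statistics accumulate.

import jax
import jax.numpy as jnp

bn = BatchRenorm(momentum=0.99)
x = jax.random.normal(jax.random.PRNGKey(0), (32, 64))
variables = bn.init(jax.random.PRNGKey(1), x, use_running_average=False)

# Training mode: batch statistics are computed and batch_stats must be mutable.
y, updates = bn.apply(variables, x, use_running_average=False, mutable=["batch_stats"])

# Evaluation mode: the updated running statistics are used, nothing is mutated.
variables = {**variables, **updates}
y_eval = bn.apply(variables, x, use_running_average=True)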
from typing import Any, Callable
import jax
import jax.numpy as jnp
from flax import core, struct
from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
import optax
import numpy as np
from ygoai.rl.env import RecordEpisodeStatistics
@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments(
new_count = tot_count
return new_mean, new_var, new_count
class TrainState(struct.PyTreeNode):
step: int
apply_fn: Callable = struct.field(pytree_node=False)
params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: optax.OptState = struct.field(pytree_node=True)
batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
def apply_gradients(self, *, grads, **kwargs):
"""Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
Note that internally this function calls ``.tx.update()`` followed by a call
to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
Args:
grads: Gradients that have the same pytree structure as ``.params``.
**kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
Returns:
An updated instance of ``self`` with ``step`` incremented by one, ``params``
and ``opt_state`` updated by applying ``grads``, and additional attributes
replaced as specified by ``kwargs``.
"""
if OVERWRITE_WITH_GRADIENT in grads:
grads_with_opt = grads['params']
params_with_opt = self.params['params']
else:
grads_with_opt = grads
params_with_opt = self.params
updates, new_opt_state = self.tx.update(
grads_with_opt, self.opt_state, params_with_opt
)
new_params_with_opt = optax.apply_updates(params_with_opt, updates)
# As implied by the OWG name, the gradients are used directly to update the
# parameters.
if OVERWRITE_WITH_GRADIENT in grads:
new_params = {
'params': new_params_with_opt,
OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
}
else:
new_params = new_params_with_opt
return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
)
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
"""Creates a new instance with ``step=0`` and initialized ``opt_state``."""
# We exclude OWG params when present because they do not need opt states.
params_with_opt = (
params['params'] if OVERWRITE_WITH_GRADIENT in params else params
)
opt_state = tx.init(params_with_opt)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
tx=tx,
opt_state=opt_state,
**kwargs,
)
\ No newline at end of file
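End-to-end sketch (illustrative, not the repository's training code): how the extended TrainState might carry batch_stats for the CrossCritic during a pmapped update. The optimizer, loss, and data shapes are made up; the pmap axis name matches the one hard-coded in CrossCritic.

from functools import partial
import jax
import jax.numpy as jnp
import optax

n_dev = jax.local_device_count()
critic = CrossCritic(channels=(128, 128), dtype=jnp.float32)
f_init = jnp.zeros((16, 256))                      # illustrative (batch, features)
variables = critic.init(jax.random.PRNGKey(0), f_init, False)

state = TrainState.create(
    apply_fn=critic.apply,
    params=variables["params"],
    batch_stats=variables["batch_stats"],
    tx=optax.adam(3e-4),
)
state = jax.device_put_replicated(state, jax.local_devices())

@partial(jax.pmap, axis_name="local_devices")
def train_step(state, f_state, targets):
    def loss_fn(params):
        preds, updates = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            f_state, True, mutable=["batch_stats"])
        return jnp.mean((preds - targets) ** 2), updates["batch_stats"]
    (loss, new_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    grads = jax.lax.pmean(grads, axis_name="local_devices")
    return state.apply_gradients(grads=grads, batch_stats=new_stats), loss

state, loss = train_step(
    state, jnp.zeros((n_dev, 16, 256)), jnp.zeros((n_dev, 16, 1)))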