Commit c2417798 authored by sbl1996@126.com's avatar sbl1996@126.com

Add RWKV

parent 4ef751bf
@@ -8,7 +8,8 @@ import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
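# recurrent RWKV-6 cell, used below as an alternative to the LSTM/GRU layers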
default_embed_init = nn.initializers.uniform(scale=0.001)
@@ -153,6 +154,7 @@ class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
version: int = 0
@nn.compact
def __call__(self, x):
@@ -196,6 +198,7 @@ class GlobalEncoder(nn.Module):
class Encoder(nn.Module):
channels: int = 128
out_channels: Optional[int] = None
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
@@ -264,8 +267,14 @@ class Encoder(nn.Module):
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
x_global = x_global.astype(self.dtype)
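# version 2 swaps the plain MLP residual for a pre-norm gated (GLU) MLP block on the global features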
if self.version == 2:
x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
x_global = fc_layer(c, dtype=jnp.float32)(x_global)
f_global = x_global + GLUMlp(c, dtype=self.dtype, param_dtype=self.param_dtype)(
layer_norm(dtype=self.dtype)(x_global))
else:
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
@@ -391,7 +400,8 @@ class Encoder(nn.Module):
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
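# out_channels (when set) widens f_state so it can feed the RWKV cell directly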
oc = self.out_channels or c
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid
@@ -498,12 +508,14 @@ class ModelArgs:
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
noam: bool = False
"""whether to use Noam architecture for the transformer layer"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
version: int = 0
"""the version of the environment and the agent"""
@@ -522,13 +534,16 @@ class RNNAgent(nn.Module):
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
rwkv_head_size: int = 32
version: int = 0
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.num_channels
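# RWKV consumes the encoder output as-is, so size f_state to the cell width; other RNN types keep the default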
oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
encoder = Encoder(
channels=c,
out_channels=oc,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
@@ -548,6 +563,10 @@ class RNNAgent(nn.Module):
elif self.rnn_type == 'gru':
rnn_layer = nn.GRUCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type == 'rwkv':
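# split rnn_channels into heads of rwkv_head_size each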
num_heads = self.rnn_channels // self.rwkv_head_size
rnn_layer = Rwkv6SelfAttention(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
elif self.rnn_type is None:
rnn_layer = None
@@ -596,5 +615,12 @@ class RNNAgent(nn.Module):
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
elif self.rnn_type == 'rwkv':
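# state shapes follow Rwkv6SelfAttention: a token-shift vector of size
# num_heads * head_size and a per-head (head_size x head_size) WKV state,
# both flattened along the feature axis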
head_size = self.rwkv_head_size
num_heads = self.rnn_channels // self.rwkv_head_size
return (
np.zeros((batch_size, num_heads*head_size)),
np.zeros((batch_size, num_heads*head_size*head_size)),
)
else:
return None
\ No newline at end of file
from typing import Tuple, Union, Optional
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
@@ -51,3 +53,62 @@ class MLP(nn.Module):
if i < n - 1 or not self.last_lin:
x = nn.leaky_relu(x, negative_slope=0.1)
return x
class GLUMlp(nn.Module):
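"""Gated MLP (SwiGLU-style): down(silu(gate(x)) * up(x)), projected back to the input width."""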
intermediate_size: int
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
use_bias: bool = False
bias_init: nn.initializers.Initializer = nn.initializers.zeros
@nn.compact
def __call__(self, inputs):
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(3)
]
actual_out_dim = inputs.shape[-1]
g = dense[0](
features=self.intermediate_size,
name="gate",
)(inputs)
g = nn.silu(g)
x = g * dense[1](
features=self.intermediate_size,
name="up",
)(inputs)
x = dense[2](
features=actual_out_dim,
kernel_init=self.last_kernel_init,  # final projection uses its own initializer
name="down",
)(x)
return x
class RMSNorm(nn.Module):
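"""Root-mean-square layer norm: scales by rsqrt(mean(x^2)) with a learned per-feature scale, no mean centering."""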
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
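# normalize in (at least) float32 for stability, then cast back to self.dtype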
dtype = jnp.promote_types(self.dtype, jnp.float32)
x = jnp.asarray(x, dtype)
x = x * jax.lax.rsqrt(jnp.square(x).mean(-1,
keepdims=True) + self.epsilon)
reduced_feature_shape = (x.shape[-1],)
scale = self.param(
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
)
x = x * scale
return jnp.asarray(x, self.dtype)
\ No newline at end of file