Commit c2417798 authored by sbl1996@126.com

Add RWKV

parent 4ef751bf
@@ -8,7 +8,8 @@ import jax.numpy as jnp
 import flax.linen as nn
 from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
-from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
 default_embed_init = nn.initializers.uniform(scale=0.001)
@@ -120,7 +121,7 @@ class CardEncoder(nn.Module):
         x_level = embed(14, c // 16)(x1[:, :, 7])
         x_counter = embed(16, c // 16)(x1[:, :, 8])
         x_negated = embed(3, c // 16)(x1[:, :, 9])
         x_atk = num_transform(x2[:, :, 0:2])
         x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
         x_def = num_transform(x2[:, :, 2:4])
@@ -153,6 +154,7 @@ class GlobalEncoder(nn.Module):
     channels: int = 128
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
+    version: int = 0

     @nn.compact
     def __call__(self, x):
@@ -196,6 +198,7 @@ class GlobalEncoder(nn.Module):
 class Encoder(nn.Module):
     channels: int = 128
+    out_channels: Optional[int] = None
     num_layers: int = 2
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: Optional[jnp.dtype] = None
@@ -264,10 +267,16 @@ class Encoder(nn.Module):
         # Global
         x_global = GlobalEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
+            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
         x_global = x_global.astype(self.dtype)
-        f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
-        f_global = fc_layer(c, dtype=self.dtype)(f_global)
+        if self.version == 2:
+            x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
+            x_global = fc_layer(c, dtype=jnp.float32)(x_global)
+            f_global = x_global + GLUMlp(c, dtype=self.dtype, param_dtype=self.param_dtype)(
+                layer_norm(dtype=self.dtype)(x_global))
+        else:
+            f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
+            f_global = fc_layer(c, dtype=self.dtype)(f_global)
         f_global = layer_norm(dtype=self.dtype)(f_global)

         # History actions
@@ -391,7 +400,8 @@ class Encoder(nn.Module):
             f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
         else:
             f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
-        f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
+        oc = self.out_channels or c
+        f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         f_state = layer_norm(dtype=self.dtype)(f_state)
         return f_actions, f_state, a_mask, valid
@@ -498,12 +508,14 @@ class ModelArgs:
     """whether to use history actions as input for agent"""
     card_mask: bool = False
     """whether to mask the padding card as ignored in the transformer"""
-    rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
+    rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
     """the type of RNN to use, None for no RNN"""
     film: bool = False
     """whether to use FiLM for the actor"""
     noam: bool = False
     """whether to use Noam architecture for the transformer layer"""
+    rwkv_head_size: int = 32
+    """the head size for the RWKV"""
     version: int = 0
     """the version of the environment and the agent"""
@@ -522,13 +534,16 @@ class RNNAgent(nn.Module):
     rnn_type: str = 'lstm'
     film: bool = False
     noam: bool = False
+    rwkv_head_size: int = 32
     version: int = 0

     @nn.compact
     def __call__(self, x, rstate, done=None, switch_or_main=None):
         c = self.num_channels
+        oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
         encoder = Encoder(
             channels=c,
+            out_channels=oc,
             num_layers=self.num_layers,
             embedding_shape=self.embedding_shape,
             dtype=self.dtype,
@@ -548,6 +563,10 @@ class RNNAgent(nn.Module):
         elif self.rnn_type == 'gru':
             rnn_layer = nn.GRUCell(
                 self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
+        elif self.rnn_type == 'rwkv':
+            num_heads = self.rnn_channels // self.rwkv_head_size
+            rnn_layer = Rwkv6SelfAttention(
+                num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
         elif self.rnn_type is None:
             rnn_layer = None
@@ -596,5 +615,12 @@ class RNNAgent(nn.Module):
             )
         elif self.rnn_type == 'gru':
             return np.zeros((batch_size, self.rnn_channels))
+        elif self.rnn_type == 'rwkv':
+            head_size = self.rwkv_head_size
+            num_heads = self.rnn_channels // self.rwkv_head_size
+            return (
+                np.zeros((batch_size, num_heads*head_size)),
+                np.zeros((batch_size, num_heads*head_size*head_size)),
+            )
         else:
             return None
\ No newline at end of file
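Taken together, the new RWKV path works as follows: Rwkv6SelfAttention is built with a head count of rnn_channels // rwkv_head_size, so rnn_channels must be a multiple of rwkv_head_size, and the recurrent state carried between steps is a pair of arrays, a token-shift vector of width num_heads * head_size and a per-head head_size x head_size WKV matrix flattened to num_heads * head_size * head_size. A minimal sketch of the resulting shapes, mirroring init_rnn_state above (the concrete numbers are illustrative, not taken from this commit):

import numpy as np

# Illustrative values only; not defaults from this commit.
rnn_channels, head_size = 512, 32
num_heads = rnn_channels // head_size                          # 16 heads, as passed to Rwkv6SelfAttention
batch_size = 8

rstate = (
    np.zeros((batch_size, num_heads * head_size)),             # token-shift state: (8, 512)
    np.zeros((batch_size, num_heads * head_size * head_size)), # flattened WKV matrix state: (8, 16384)
)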
 from typing import Tuple, Union, Optional
+import functools
+import jax
 import jax.numpy as jnp
 import flax.linen as nn
@@ -51,3 +53,62 @@ class MLP(nn.Module):
         if i < n - 1 or not self.last_lin:
             x = nn.leaky_relu(x, negative_slope=0.1)
         return x
+
+
+class GLUMlp(nn.Module):
+    intermediate_size: int
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
+    last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
+    bias_init: nn.initializers.Initializer = nn.initializers.zeros  # assumed default; referenced below but missing from the captured diff
+    use_bias: bool = False
+
+    @nn.compact
+    def __call__(self, inputs):
+        dense = [
+            functools.partial(
+                nn.DenseGeneral,
+                use_bias=self.use_bias,
+                dtype=self.dtype,
+                param_dtype=self.param_dtype,
+                kernel_init=self.kernel_init,
+                bias_init=self.bias_init,
+            ) for _ in range(3)
+        ]
+
+        actual_out_dim = inputs.shape[-1]
+        g = dense[0](
+            features=self.intermediate_size,
+            name="gate",
+        )(inputs)
+        g = nn.silu(g)
+        x = g * dense[1](
+            features=self.intermediate_size,
+            name="up",
+        )(inputs)
+        x = dense[2](
+            features=actual_out_dim,
+            name="down",
+        )(x)
+        return x
+
+
+class RMSNorm(nn.Module):
+    epsilon: float = 1e-6
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, x):
+        dtype = jnp.promote_types(self.dtype, jnp.float32)
+        x = jnp.asarray(x, dtype)
+        x = x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.epsilon)
+        reduced_feature_shape = (x.shape[-1],)
+        scale = self.param(
+            "scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
+        )
+        x = x * scale
+        return jnp.asarray(x, self.dtype)
\ No newline at end of file
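For reference, a minimal usage sketch of the two modules added to ygoai.rl.jax.modules, mirroring the pre-norm residual GLU pattern used in the version-2 global branch above; the shapes, the PRNG key, and the Block wrapper are illustrative, not part of the commit:

import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.modules import GLUMlp, RMSNorm

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        # residual GLU MLP applied to a normalized input, then RMS-normalize the result
        h = nn.LayerNorm()(x)
        x = x + GLUMlp(intermediate_size=2 * x.shape[-1])(h)
        return RMSNorm()(x)

x = jnp.ones((4, 128))                           # (batch, channels), illustrative
params = Block().init(jax.random.PRNGKey(0), x)
y = Block().apply(params, x)                     # output shape matches the input: (4, 128)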