Commit c2417798 authored by sbl1996@126.com

Add RWKV

parent 4ef751bf
@@ -8,7 +8,8 @@ import jax.numpy as jnp
import flax.linen as nn

from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

default_embed_init = nn.initializers.uniform(scale=0.001)
@@ -120,7 +121,7 @@ class CardEncoder(nn.Module):
        x_level = embed(14, c // 16)(x1[:, :, 7])
        x_counter = embed(16, c // 16)(x1[:, :, 8])
        x_negated = embed(3, c // 16)(x1[:, :, 9])

        x_atk = num_transform(x2[:, :, 0:2])
        x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
        x_def = num_transform(x2[:, :, 2:4])
@@ -153,6 +154,7 @@ class GlobalEncoder(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    version: int = 0

    @nn.compact
    def __call__(self, x):
@@ -196,6 +198,7 @@ class GlobalEncoder(nn.Module):

class Encoder(nn.Module):
    channels: int = 128
    out_channels: Optional[int] = None
    num_layers: int = 2
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: Optional[jnp.dtype] = None
@@ -264,10 +267,16 @@ class Encoder(nn.Module):
        # Global
        x_global = GlobalEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
        x_global = x_global.astype(self.dtype)
        if self.version == 2:
            x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
            x_global = fc_layer(c, dtype=jnp.float32)(x_global)
            f_global = x_global + GLUMlp(c, dtype=self.dtype, param_dtype=self.param_dtype)(
                layer_norm(dtype=self.dtype)(x_global))
        else:
            f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
            f_global = fc_layer(c, dtype=self.dtype)(f_global)
        f_global = layer_norm(dtype=self.dtype)(f_global)

        # History actions
@@ -391,7 +400,8 @@ class Encoder(nn.Module):
            f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
        else:
            f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
        oc = self.out_channels or c
        f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
        f_state = layer_norm(dtype=self.dtype)(f_state)
        return f_actions, f_state, a_mask, valid
@@ -498,12 +508,14 @@ class ModelArgs:
    """whether to use history actions as input for agent"""
    card_mask: bool = False
    """whether to mask the padding card as ignored in the transformer"""
    rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
    """the type of RNN to use, None for no RNN"""
    film: bool = False
    """whether to use FiLM for the actor"""
    noam: bool = False
    """whether to use Noam architecture for the transformer layer"""
    rwkv_head_size: int = 32
    """the head size for the RWKV"""
    version: int = 0
    """the version of the environment and the agent"""
@@ -522,13 +534,16 @@ class RNNAgent(nn.Module):
    rnn_type: str = 'lstm'
    film: bool = False
    noam: bool = False
    rwkv_head_size: int = 32
    version: int = 0

    @nn.compact
    def __call__(self, x, rstate, done=None, switch_or_main=None):
        c = self.num_channels
        oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
        encoder = Encoder(
            channels=c,
            out_channels=oc,
            num_layers=self.num_layers,
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
@@ -548,6 +563,10 @@ class RNNAgent(nn.Module):
        elif self.rnn_type == 'gru':
            rnn_layer = nn.GRUCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type == 'rwkv':
            num_heads = self.rnn_channels // self.rwkv_head_size
            rnn_layer = Rwkv6SelfAttention(
                num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
        elif self.rnn_type is None:
            rnn_layer = None
@@ -596,5 +615,12 @@ class RNNAgent(nn.Module):
            )
        elif self.rnn_type == 'gru':
            return np.zeros((batch_size, self.rnn_channels))
        elif self.rnn_type == 'rwkv':
            head_size = self.rwkv_head_size
            num_heads = self.rnn_channels // self.rwkv_head_size
            return (
                np.zeros((batch_size, num_heads*head_size)),
                np.zeros((batch_size, num_heads*head_size*head_size)),
            )
        else:
            return None
\ No newline at end of file
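
A minimal sketch of driving the new rnn_type='rwkv' path for a single recurrent step (assumed sizes: rnn_channels=512, rwkv_head_size=32, batch of 4; not taken from the training script). The carry layout matches init_rnn_state above: a token-shift buffer plus a flattened wkv state, and Rwkv6SelfAttention is the cell added in ygoai/rl/jax/rwkv.py below.

import jax
import jax.numpy as jnp
import numpy as np
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

rnn_channels, head_size, batch_size = 512, 32, 4
num_heads = rnn_channels // head_size

# Same layout as RNNAgent.init_rnn_state for rnn_type='rwkv':
# (token-shift buffer, flattened wkv state).
rstate = (
    np.zeros((batch_size, num_heads * head_size), dtype=np.float32),
    np.zeros((batch_size, num_heads * head_size * head_size), dtype=np.float32),
)

cell = Rwkv6SelfAttention(num_heads=num_heads)
x = jnp.zeros((batch_size, rnn_channels))          # the encoder's f_state would go here
params = cell.init(jax.random.PRNGKey(0), rstate, x)
rstate, y = cell.apply(params, rstate, x)          # one step: y has shape (batch_size, rnn_channels)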
ygoai/rl/jax/modules.py

from typing import Tuple, Union, Optional
import functools

import jax
import jax.numpy as jnp
import flax.linen as nn

@@ -51,3 +53,62 @@ class MLP(nn.Module):
            if i < n - 1 or not self.last_lin:
                x = nn.leaky_relu(x, negative_slope=0.1)
        return x
class GLUMlp(nn.Module):
intermediate_size: int
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = False
    bias_init: nn.initializers.Initializer = nn.initializers.zeros  # referenced below when building the Dense partials
@nn.compact
def __call__(self, inputs):
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(3)
]
actual_out_dim = inputs.shape[-1]
g = dense[0](
features=self.intermediate_size,
name="gate",
)(inputs)
g = nn.silu(g)
x = g * dense[1](
features=self.intermediate_size,
name="up",
)(inputs)
x = dense[2](
features=actual_out_dim,
name="down",
)(x)
return x
class RMSNorm(nn.Module):
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
dtype = jnp.promote_types(self.dtype, jnp.float32)
x = jnp.asarray(x, dtype)
x = x * jax.lax.rsqrt(jnp.square(x).mean(-1,
keepdims=True) + self.epsilon)
reduced_feature_shape = (x.shape[-1],)
scale = self.param(
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
)
x = x * scale
return jnp.asarray(x, self.dtype)
\ No newline at end of file
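
A quick usage sketch for the two modules added here (illustrative shapes, not from the training code): GLUMlp is a SwiGLU-style block whose output width matches its input, and RMSNorm rescales by the root mean square over the last axis.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.modules import GLUMlp, RMSNorm

x = jnp.ones((4, 128))                 # (batch, features)

mlp = GLUMlp(intermediate_size=256)    # gate/up projections to 256, SiLU-gated, down back to 128
params = mlp.init(jax.random.PRNGKey(0), x)
y = mlp.apply(params, x)               # (4, 128)

norm = RMSNorm()
scale = norm.init(jax.random.PRNGKey(1), y)
z = norm.apply(scale, y)               # (4, 128), normalized over the last axis, then scaled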
ygoai/rl/jax/rwkv.py

from typing import Optional
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import (
    Dtype,
)

TIME_MIX_EXTRA_DIM = 32
def hf_rwkv6_linear_attention(receptance, key, value, time_decay, time_first, state):
    # receptance: (seq_length, batch, num_heads*head_size)
    # key: (seq_length, batch, num_heads*head_size)
    # value: (seq_length, batch, num_heads*head_size)
    # time_decay: (seq_length, batch, num_heads*head_size)
    # time_first: (num_heads, head_size)
    # state: (batch, num_heads, head_size, head_size)
    # out: (seq_length, batch, num_heads*head_size)
    # A 2D receptance of shape (batch, num_heads*head_size) is treated as a single step (seq_length == 1).
if receptance.ndim == 2:
receptance = receptance[None]
shape = state.shape
seq_length, batch, _ = receptance.shape
num_heads, head_size = time_first.shape
key = key.reshape(seq_length, batch, num_heads, head_size)
value = value.reshape(seq_length, batch, num_heads, head_size)
receptance = receptance.reshape(seq_length, batch, num_heads, head_size)
state = state.reshape(batch, num_heads, head_size, head_size)
time_decay = jnp.exp(-jnp.exp(time_decay)).reshape(seq_length, batch, num_heads, head_size)
time_first = time_first.reshape(num_heads, head_size, 1) # h, n -> h, n, 1
def body_fn(carry, inp):
state, tf = carry
r, k, v, w = inp
a = k[:, :, :, None] @ v[:, :, None, :]
out = (r[:, :, None, :] @ (tf * a + state)).squeeze(2)
state = a + w[:, :, :, None] * state
return (state, tf), out
if seq_length == 1:
(state, _), out = body_fn((state, time_first), (receptance[0], key[0], value[0], time_decay[0]))
out = out[None, ...]
else:
(state, _), out = jax.lax.scan(body_fn, (state, time_first), (receptance, key, value, time_decay))
out = out.reshape(seq_length, batch, num_heads * head_size)
state = state.reshape(shape)
return out, state
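# Toy shape check for hf_rwkv6_linear_attention (assumed sizes, for illustration only):
#   T, B, H, N = 4, 2, 3, 8                      # seq_length, batch, num_heads, head_size
#   r = k = v = jnp.zeros((T, B, H * N))
#   w = jnp.zeros((T, B, H * N))                 # raw time_decay; exp(-exp(.)) is applied inside
#   u = jnp.zeros((H, N))                        # time_first / time_faaaa
#   s = jnp.zeros((B, H, N, N))                  # recurrent wkv state
#   out, s = hf_rwkv6_linear_attention(r, k, v, w, u, s)
#   out.shape == (4, 2, 24); s.shape == (2, 3, 8, 8)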
def time_p_initilizer(ratio):
def init_fn(key, shape, dtype):
w = jnp.arange(shape[-1], dtype=dtype) / shape[-1]
p = 1.0 - jnp.power(w, ratio)
p = jnp.broadcast_to(p, shape)
return p
return init_fn
def time_decay_init(key, shape, dtype):
attention_hidden_size = shape[-1]
ratio_0_to_1 = 0
decay_speed = [
-6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
for h in range(attention_hidden_size)
]
w = jnp.array(decay_speed, dtype=dtype)
return w[None, None, :]
def time_faaaa_init(key, shape, dtype):
attention_hidden_size = shape[0] * shape[1]
ratio_0_to_1 = 0
w = [
(1.0 - (i / (attention_hidden_size - 1.0))) * ratio_0_to_1 + 0.1 * ((i + 1) % 3 - 1)
for i in range(attention_hidden_size)
]
w = jnp.array(w, dtype=dtype)
return w.reshape(shape)
class Rwkv6SelfAttention(nn.Module):
num_heads: int
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
B, C = inputs.shape
def time_w(name, shape):
return self.param(
name,
lambda key, shape, dtype: jax.random.uniform(key, shape, dtype, -1e-4, 1e-4),
shape, self.param_dtype)
dense = partial(
nn.Dense,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
shifted, state = carry
x = inputs
xx = shifted - x
shifted = inputs
time_maa_x = self.param(
"time_maa_x", time_p_initilizer(1.0), (1, C), self.param_dtype)
time_maa_w1 = time_w("time_maa_w1", (C, TIME_MIX_EXTRA_DIM * 5))
time_maa_w2 = time_w("time_maa_w2", (5, TIME_MIX_EXTRA_DIM, C))
xxx = x + xx * time_maa_x
xxx = jnp.tanh(xxx @ time_maa_w1).reshape((B, 5, -1)).transpose((1, 0, 2))
xxx = (xxx @ time_maa_w2).reshape((5, B, C))
time_maa_wkvrg = self.param(
"time_maa_wkvrg", time_p_initilizer(1.0), (5, 1, C), self.param_dtype)
x = x[None] + xx[None] * (time_maa_wkvrg + xxx)
time_decay = x[0]
rkvg = x[1:5]
w_rkvg = self.param(
"w_rkvg", nn.initializers.lecun_normal(),
(4, 1, C, C), self.param_dtype)
rkvg = rkvg[:, :, None, :] @ w_rkvg
receptance, key, value, gate = [
rkvg[i, :, 0] for i in range(4)
]
time_decay_w1 = time_w("time_decay_w1", (C, TIME_MIX_EXTRA_DIM))
time_decay_w2 = time_w("time_decay_w2", (TIME_MIX_EXTRA_DIM, C))
time_decay = jnp.tanh(time_decay @ time_decay_w1) @ time_decay_w2
time_decay_p = self.param(
"time_decay", time_decay_init, (1, C), self.param_dtype)
time_decay = time_decay_p + time_decay
time_faaaa = self.param(
"time_faaaa", time_faaaa_init, (self.num_heads, C // self.num_heads), self.param_dtype)
out, state = hf_rwkv6_linear_attention(
receptance, key, value, time_decay, time_faaaa, state,
)
out = out[0]
out = nn.GroupNorm(
num_groups=self.num_heads, epsilon=(1e-5)*(8**2))(out)
out = out * jax.nn.swish(gate)
out = dense(features=C, name="output")(out)
return (shifted, state), out
class Rwkv6SelfAttention0(nn.Module):
num_heads: int
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
shape1 = inputs.shape
shape2 = carry[0].shape
if inputs.ndim == 2:
inputs = inputs[None, ...]
T, B, C = inputs.shape
def time_p(name, ratio=1.0):
return self.param(
name, time_p_initilizer(ratio), (1, 1, C), self.param_dtype)
def time_w(name, shape):
return self.param(
name,
lambda key, shape, dtype: jax.random.uniform(key, shape, dtype, -1e-4, 1e-4),
shape, self.param_dtype)
dense = partial(
nn.Dense,
features=C,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
shifted, state = carry
if shifted.ndim == 2:
shifted = shifted[None, ...]
if T != 1:
            shifted = jnp.concatenate([
                shifted, inputs[:-1]], axis=0)
x = inputs
xx = shifted - x
shifted = inputs[-1].reshape(shape2)
xxx = x + xx * time_p('time_maa_x')
time_maa_w1 = time_w("time_maa_w1", (C, TIME_MIX_EXTRA_DIM * 5))
xxx = jnp.tanh(xxx @ time_maa_w1).reshape((T*B, 5, -1)).transpose((1, 0, 2))
time_maa_w2 = time_w("time_maa_w2", (5, TIME_MIX_EXTRA_DIM, C))
xxx = (xxx @ time_maa_w2).reshape((5, T, B, -1))
mw, mk, mv, mr, mg = [
x[0] for x in jnp.split(xxx, 5, axis=0)
]
time_decay = x + xx * (time_p("time_maa_w") + mw)
key = x + xx * (time_p("time_maa_k") + mk)
value = x + xx * (time_p("time_maa_v") + mv)
receptance = x + xx * (time_p("time_maa_r", 0.5) + mr)
gate = x + xx * (time_p("time_maa_g", 0.5) + mg)
receptance = dense(name="receptance")(receptance)
key = dense(name="key")(key)
value = dense(name="value")(value)
gate = jax.nn.swish(dense(name="gate")(gate))
time_decay_w1 = time_w("time_decay_w1", (C, TIME_MIX_EXTRA_DIM))
time_decay_w2 = time_w("time_decay_w2", (TIME_MIX_EXTRA_DIM, C))
time_decay = jnp.tanh(time_decay @ time_decay_w1) @ time_decay_w2
time_decay_p = self.param(
"time_decay", time_decay_init, (1, 1, C), self.param_dtype)
time_decay = time_decay_p + time_decay
time_faaaa = self.param(
"time_faaaa", time_faaaa_init, (self.num_heads, C // self.num_heads), self.param_dtype)
out, state = hf_rwkv6_linear_attention(
receptance, key, value, time_decay, time_faaaa, state,
)
out = nn.GroupNorm(
num_groups=self.num_heads, epsilon=(1e-5)*(8**2))(out)
out = out * gate
out = dense(name="output")(out)
out = out.reshape(shape1)
return (shifted, state), out
class Rwkv6FeedForward(nn.Module):
intermediate_size: Optional[int] = None
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
        assert inputs.ndim == 3, "inputs must have shape (seq_len, batch, features)"
T, B, C = inputs.shape
def time_p(name, ratio=1.0):
return self.param(
name, time_p_initilizer(ratio), (1, 1, C), self.param_dtype)
intermediate_size = self.intermediate_size or int((C * 3.5) // 32 * 32)
dense = partial(
nn.Dense,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
_shifted, _state, shifted = carry
if shifted.ndim == 2:
shifted = shifted[None, ...]
if T != 1:
            shifted = jnp.concatenate([
                shifted, inputs[:-1]], axis=0)
x = inputs
xx = shifted - x
key = x + xx * time_p('time_maa_k')
receptance = x + xx * time_p('time_maa_r', 0.5)
key = jnp.square(jax.nn.relu(
dense(features=intermediate_size,name="key")(key)))
value = dense(features=C,name="value")(key)
receptance = jax.nn.sigmoid(
dense(features=C,name="receptance")(receptance))
out = value * receptance
return (_shifted, _state, inputs[-1]), out
class Rwkv6Block(nn.Module):
num_heads: int
intermediate_size: Optional[int] = None
layer_norm_epsilon: float = 1e-5
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
layer_norm = partial(
nn.LayerNorm, epsilon=self.layer_norm_epsilon)
x = inputs
y = layer_norm(name="ln1")(x)
carry, y = Rwkv6SelfAttention(
num_heads=self.num_heads,
dtype=self.dtype,
param_dtype=self.param_dtype
)(carry, y)
x = x + y
y = layer_norm(name="ln2")(x)
carry, y = Rwkv6FeedForward(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype
)(carry, y)
x = x + y
return carry, x
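
A sketch of unrolling the single-step Rwkv6SelfAttention over a sequence with Flax's nn.scan (assumed shapes and head size; the agent wires the cell through its own rollout logic, so this only illustrates the carry/input contract).

import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

T, B, C, head_size = 8, 2, 128, 32
num_heads = C // head_size

# Scan the cell over the leading (time) axis; parameters are shared across steps.
ScanRwkv = nn.scan(
    Rwkv6SelfAttention,
    variable_broadcast="params",
    split_rngs={"params": False},
)

carry = (
    jnp.zeros((B, C)),                                  # token-shift buffer
    jnp.zeros((B, num_heads * head_size * head_size)),  # flattened wkv state
)
xs = jnp.zeros((T, B, C))

model = ScanRwkv(num_heads=num_heads)
params = model.init(jax.random.PRNGKey(0), carry, xs)
carry, ys = model.apply(params, carry, xs)              # ys: (T, B, C)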