Commit c2417798 authored by sbl1996@126.com

Add RWKV

parent 4ef751bf
@@ -8,7 +8,8 @@ import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
default_embed_init = nn.initializers.uniform(scale=0.001)
@@ -120,7 +121,7 @@ class CardEncoder(nn.Module):
x_level = embed(14, c // 16)(x1[:, :, 7])
x_counter = embed(16, c // 16)(x1[:, :, 8])
x_negated = embed(3, c // 16)(x1[:, :, 9])
x_atk = num_transform(x2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x2[:, :, 2:4])
@@ -153,6 +154,7 @@ class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
version: int = 0
@nn.compact
def __call__(self, x):
@@ -196,6 +198,7 @@ class GlobalEncoder(nn.Module):
class Encoder(nn.Module):
channels: int = 128
out_channels: Optional[int] = None
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
@@ -264,10 +267,16 @@ class Encoder(nn.Module):
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
x_global = x_global.astype(self.dtype)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
if self.version == 2:
x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
x_global = fc_layer(c, dtype=jnp.float32)(x_global)
f_global = x_global + GLUMlp(c, dtype=self.dtype, param_dtype=self.param_dtype)(
layer_norm(dtype=self.dtype)(x_global))
else:
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
# History actions
@@ -391,7 +400,8 @@ class Encoder(nn.Module):
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
oc = self.out_channels or c
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid
@@ -498,12 +508,14 @@ class ModelArgs:
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
noam: bool = False
"""whether to use Noam architecture for the transformer layer"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
version: int = 0
"""the version of the environment and the agent"""
@@ -522,13 +534,16 @@ class RNNAgent(nn.Module):
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
rwkv_head_size: int = 32
version: int = 0
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.num_channels
oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
encoder = Encoder(
channels=c,
out_channels=oc,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
@@ -548,6 +563,10 @@ class RNNAgent(nn.Module):
elif self.rnn_type == 'gru':
rnn_layer = nn.GRUCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type == 'rwkv':
num_heads = self.rnn_channels // self.rwkv_head_size
rnn_layer = Rwkv6SelfAttention(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
elif self.rnn_type is None:
rnn_layer = None
@@ -596,5 +615,12 @@ class RNNAgent(nn.Module):
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
elif self.rnn_type == 'rwkv':
head_size = self.rwkv_head_size
num_heads = self.rnn_channels // self.rwkv_head_size
return (
np.zeros((batch_size, num_heads*head_size)),
np.zeros((batch_size, num_heads*head_size*head_size)),
)
else:
return None
\ No newline at end of file
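For reference, a minimal sketch (not part of the commit) of the recurrent state the 'rwkv' branch above allocates. The concrete sizes are hypothetical; out_channels is presumably forced to rnn_channels for this branch because Rwkv6SelfAttention keeps the feature width of its input, so the encoder must already emit rnn_channels features.

import numpy as np

rnn_channels, rwkv_head_size, batch_size = 512, 32, 8       # hypothetical sizes
num_heads = rnn_channels // rwkv_head_size                  # 16 RWKV heads
rwkv_state = (
    np.zeros((batch_size, num_heads * rwkv_head_size)),                    # token-shift buffer, (8, 512)
    np.zeros((batch_size, num_heads * rwkv_head_size * rwkv_head_size)),   # flattened wkv state, (8, 16384)
)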
from typing import Tuple, Union, Optional
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
@@ -51,3 +53,62 @@ class MLP(nn.Module):
if i < n - 1 or not self.last_lin:
x = nn.leaky_relu(x, negative_slope=0.1)
return x
class GLUMlp(nn.Module):
intermediate_size: int
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
use_bias: bool = False
bias_init: nn.initializers.Initializer = nn.initializers.zeros
@nn.compact
def __call__(self, inputs):
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(3)
]
actual_out_dim = inputs.shape[-1]
g = dense[0](
features=self.intermediate_size,
name="gate",
)(inputs)
g = nn.silu(g)
x = g * dense[1](
features=self.intermediate_size,
name="up",
)(inputs)
x = dense[2](
features=actual_out_dim,
name="down",
)(x)
return x
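A minimal usage sketch (not part of the commit) for GLUMlp, assuming it is importable as in the agent's import above; the output keeps the input's feature width, and intermediate_size only sets the width of the gate/up projections.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.modules import GLUMlp

mlp = GLUMlp(intermediate_size=256)
x = jnp.ones((8, 128))
params = mlp.init(jax.random.PRNGKey(0), x)
y = mlp.apply(params, x)   # (8, 128): down(silu(gate(x)) * up(x))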
class RMSNorm(nn.Module):
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
dtype = jnp.promote_types(self.dtype, jnp.float32)
x = jnp.asarray(x, dtype)
x = x * jax.lax.rsqrt(jnp.square(x).mean(-1,
keepdims=True) + self.epsilon)
reduced_feature_shape = (x.shape[-1],)
scale = self.param(
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
)
x = x * scale
return jnp.asarray(x, self.dtype)
\ No newline at end of file
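A minimal sketch (not part of the commit) of RMSNorm: unlike LayerNorm it subtracts no mean and adds no bias; it rescales each feature vector to roughly unit root-mean-square and applies a learned per-feature scale.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.modules import RMSNorm

norm = RMSNorm()
x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
params = norm.init(jax.random.PRNGKey(0), x)
y = norm.apply(params, x)
# With the initial scale of ones, jnp.mean(y ** 2) is ~1.0 (up to epsilon).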
from typing import Optional
from functools import partial
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import (
Dtype,
)
TIME_MIX_EXTRA_DIM = 32
def hf_rwkv6_linear_attention(receptance, key, value, time_decay, time_first, state):
# receptance: (seq_length, batch, num_heads*head_size)
# key: (seq_length, batch, num_heads*head_size)
# value: (seq_length, batch, num_heads*head_size)
# time_decay: (seq_length, batch, num_heads*head_size)
# time_first: (num_heads, head_size)
# state: (batch, num_heads, head_size, head_size)
# out: (seq_length, batch, num_heads, head_size)
if receptance.ndim == 2:
receptance = receptance[None]
shape = state.shape
seq_length, batch, _ = receptance.shape
num_heads, head_size = time_first.shape
key = key.reshape(seq_length, batch, num_heads, head_size)
value = value.reshape(seq_length, batch, num_heads, head_size)
receptance = receptance.reshape(seq_length, batch, num_heads, head_size)
state = state.reshape(batch, num_heads, head_size, head_size)
time_decay = jnp.exp(-jnp.exp(time_decay)).reshape(seq_length, batch, num_heads, head_size)
time_first = time_first.reshape(num_heads, head_size, 1) # h, n -> h, n, 1
def body_fn(carry, inp):
state, tf = carry
r, k, v, w = inp
a = k[:, :, :, None] @ v[:, :, None, :]
out = (r[:, :, None, :] @ (tf * a + state)).squeeze(2)
state = a + w[:, :, :, None] * state
return (state, tf), out
if seq_length == 1:
(state, _), out = body_fn((state, time_first), (receptance[0], key[0], value[0], time_decay[0]))
out = out[None, ...]
else:
(state, _), out = jax.lax.scan(body_fn, (state, time_first), (receptance, key, value, time_decay))
out = out.reshape(seq_length, batch, num_heads * head_size)
state = state.reshape(shape)
return out, state
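A worked single-step, single-head sketch (not part of the commit) of the recurrence that body_fn applies per head, with the decay already mapped through exp(-exp(.)) and the bonus u = time_first:

import jax.numpy as jnp

head_size = 4
r = jnp.ones(head_size)                       # receptance for this step
k = jnp.ones(head_size)
v = jnp.arange(head_size, dtype=jnp.float32)
w = jnp.full(head_size, 0.9)                  # per-channel decay, i.e. exp(-exp(time_decay))
u = jnp.full(head_size, 0.1)                  # "time_first" bonus for the current token
state = jnp.zeros((head_size, head_size))

a = k[:, None] @ v[None, :]                   # rank-1 outer product k v^T
out = r[None, :] @ (u[:, None] * a + state)   # read-out: bonus-weighted current pair plus history
state = a + w[:, None] * state                # decay the old state, accumulate the new pair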
def time_p_initilizer(ratio):
def init_fn(key, shape, dtype):
w = jnp.arange(shape[-1], dtype=dtype) / shape[-1]
p = 1.0 - jnp.power(w, ratio)
p = jnp.broadcast_to(p, shape)
return p
return init_fn
def time_decay_init(key, shape, dtype):
attention_hidden_size = shape[-1]
ratio_0_to_1 = 0
decay_speed = [
-6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
for h in range(attention_hidden_size)
]
w = jnp.array(decay_speed, dtype=dtype)
return w[None, None, :]
def time_faaaa_init(key, shape, dtype):
attention_hidden_size = shape[0] * shape[1]
ratio_0_to_1 = 0
w = [
(1.0 - (i / (attention_hidden_size - 1.0))) * ratio_0_to_1 + 0.1 * ((i + 1) % 3 - 1)
for i in range(attention_hidden_size)
]
w = jnp.array(w, dtype=dtype)
return w.reshape(shape)
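For intuition, a small sketch (not part of the commit) of what these initializers produce for a hidden size of 8 and 2 heads; note that time_decay_init ignores the leading dimensions of the requested shape and always returns (1, 1, C).

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)                              # unused by these initializers
td = time_decay_init(key, (1, 8), jnp.float32)           # (1, 1, 8), ramps from -6.0 up to -1.0
tf = time_faaaa_init(key, (2, 4), jnp.float32)           # (2, 4), per-head bonus values in {-0.1, 0.0, 0.1}
tp = time_p_initilizer(1.0)(key, (1, 8), jnp.float32)    # (1, 8), mixing weights 1 - i/8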
class Rwkv6SelfAttention(nn.Module):
num_heads: int
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
B, C = inputs.shape
def time_w(name, shape):
return self.param(
name,
lambda key, shape, dtype: jax.random.uniform(key, shape, dtype, -1e-4, 1e-4),
shape, self.param_dtype)
dense = partial(
nn.Dense,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
shifted, state = carry
x = inputs
xx = shifted - x
shifted = inputs
time_maa_x = self.param(
"time_maa_x", time_p_initilizer(1.0), (1, C), self.param_dtype)
time_maa_w1 = time_w("time_maa_w1", (C, TIME_MIX_EXTRA_DIM * 5))
time_maa_w2 = time_w("time_maa_w2", (5, TIME_MIX_EXTRA_DIM, C))
xxx = x + xx * time_maa_x
xxx = jnp.tanh(xxx @ time_maa_w1).reshape((B, 5, -1)).transpose((1, 0, 2))
xxx = (xxx @ time_maa_w2).reshape((5, B, C))
time_maa_wkvrg = self.param(
"time_maa_wkvrg", time_p_initilizer(1.0), (5, 1, C), self.param_dtype)
x = x[None] + xx[None] * (time_maa_wkvrg + xxx)
time_decay = x[0]
rkvg = x[1:5]
w_rkvg = self.param(
"w_rkvg", nn.initializers.lecun_normal(),
(4, 1, C, C), self.param_dtype)
rkvg = rkvg[:, :, None, :] @ w_rkvg
receptance, key, value, gate = [
rkvg[i, :, 0] for i in range(4)
]
time_decay_w1 = time_w("time_decay_w1", (C, TIME_MIX_EXTRA_DIM))
time_decay_w2 = time_w("time_decay_w2", (TIME_MIX_EXTRA_DIM, C))
time_decay = jnp.tanh(time_decay @ time_decay_w1) @ time_decay_w2
time_decay_p = self.param(
"time_decay", time_decay_init, (1, C), self.param_dtype)
time_decay = time_decay_p + time_decay
time_faaaa = self.param(
"time_faaaa", time_faaaa_init, (self.num_heads, C // self.num_heads), self.param_dtype)
out, state = hf_rwkv6_linear_attention(
receptance, key, value, time_decay, time_faaaa, state,
)
out = out[0]
out = nn.GroupNorm(
num_groups=self.num_heads, epsilon=(1e-5)*(8**2))(out)
out = out * jax.nn.swish(gate)
out = dense(features=C, name="output")(out)
return (shifted, state), out
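A minimal single-step usage sketch (not part of the commit), assuming the import path used by the agent above. The carry layout matches the agent's RWKV state: a (B, C) token-shift buffer and a flattened (B, C * head_size) wkv state; the cell preserves the feature width C = num_heads * head_size.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

num_heads, head_size, batch = 4, 32, 2
C = num_heads * head_size
cell = Rwkv6SelfAttention(num_heads=num_heads)

carry = (jnp.zeros((batch, C)), jnp.zeros((batch, C * head_size)))
x = jnp.ones((batch, C))
params = cell.init(jax.random.PRNGKey(0), carry, x)
carry, y = cell.apply(params, carry, x)       # y.shape == (batch, C); carry holds the updated shift buffer and wkv state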
class Rwkv6SelfAttention0(nn.Module):
num_heads: int
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
shape1 = inputs.shape
shape2 = carry[0].shape
if inputs.ndim == 2:
inputs = inputs[None, ...]
T, B, C = inputs.shape
def time_p(name, ratio=1.0):
return self.param(
name, time_p_initilizer(ratio), (1, 1, C), self.param_dtype)
def time_w(name, shape):
return self.param(
name,
lambda key, shape, dtype: jax.random.uniform(key, shape, dtype, -1e-4, 1e-4),
shape, self.param_dtype)
dense = partial(
nn.Dense,
features=C,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
shifted, state = carry
if shifted.ndim == 2:
shifted = shifted[None, ...]
if T != 1:
shifted = jnp.concatenate([
shifted, inputs[:-1]], axis=0)  # token shift: carried last token, then inputs[:-1]
x = inputs
xx = shifted - x
shifted = inputs[-1].reshape(shape2)
xxx = x + xx * time_p('time_maa_x')
time_maa_w1 = time_w("time_maa_w1", (C, TIME_MIX_EXTRA_DIM * 5))
xxx = jnp.tanh(xxx @ time_maa_w1).reshape((T*B, 5, -1)).transpose((1, 0, 2))
time_maa_w2 = time_w("time_maa_w2", (5, TIME_MIX_EXTRA_DIM, C))
xxx = (xxx @ time_maa_w2).reshape((5, T, B, -1))
mw, mk, mv, mr, mg = [
x[0] for x in jnp.split(xxx, 5, axis=0)
]
time_decay = x + xx * (time_p("time_maa_w") + mw)
key = x + xx * (time_p("time_maa_k") + mk)
value = x + xx * (time_p("time_maa_v") + mv)
receptance = x + xx * (time_p("time_maa_r", 0.5) + mr)
gate = x + xx * (time_p("time_maa_g", 0.5) + mg)
receptance = dense(name="receptance")(receptance)
key = dense(name="key")(key)
value = dense(name="value")(value)
gate = jax.nn.swish(dense(name="gate")(gate))
time_decay_w1 = time_w("time_decay_w1", (C, TIME_MIX_EXTRA_DIM))
time_decay_w2 = time_w("time_decay_w2", (TIME_MIX_EXTRA_DIM, C))
time_decay = jnp.tanh(time_decay @ time_decay_w1) @ time_decay_w2
time_decay_p = self.param(
"time_decay", time_decay_init, (1, 1, C), self.param_dtype)
time_decay = time_decay_p + time_decay
time_faaaa = self.param(
"time_faaaa", time_faaaa_init, (self.num_heads, C // self.num_heads), self.param_dtype)
out, state = hf_rwkv6_linear_attention(
receptance, key, value, time_decay, time_faaaa, state,
)
out = nn.GroupNorm(
num_groups=self.num_heads, epsilon=(1e-5)*(8**2))(out)
out = out * gate
out = dense(name="output")(out)
out = out.reshape(shape1)
return (shifted, state), out
class Rwkv6FeedForward(nn.Module):
intermediate_size: Optional[int] = None
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
assert inputs.ndim == 3, "inputs must have shape (seq_len, batch, features)"
T, B, C = inputs.shape
def time_p(name, ratio=1.0):
return self.param(
name, time_p_initilizer(ratio), (1, 1, C), self.param_dtype)
intermediate_size = self.intermediate_size or int((C * 3.5) // 32 * 32)
dense = partial(
nn.Dense,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
_shifted, _state, shifted = carry
if shifted.ndim == 2:
shifted = shifted[None, ...]
if T != 1:
shifted = jnp.concatenate([
shifted, inputs[:-1]], axis=0)  # token shift: carried last token, then inputs[:-1]
x = inputs
xx = shifted - x
key = x + xx * time_p('time_maa_k')
receptance = x + xx * time_p('time_maa_r', 0.5)
key = jnp.square(jax.nn.relu(
dense(features=intermediate_size,name="key")(key)))
value = dense(features=C,name="value")(key)
receptance = jax.nn.sigmoid(
dense(features=C,name="receptance")(receptance))
out = value * receptance
return (_shifted, _state, inputs[-1]), out
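A minimal sketch (not part of the commit) of the feed-forward cell: it threads the first two carry slots through untouched (they belong to the attention cell) and only owns the last slot, its own token-shift buffer; inputs are time-major (T, B, C).

import jax
import jax.numpy as jnp
from ygoai.rl.jax.rwkv import Rwkv6FeedForward

T, B, C = 5, 2, 64
ffn = Rwkv6FeedForward()                      # intermediate_size defaults to int((C * 3.5) // 32 * 32)
x = jnp.ones((T, B, C))
carry = (None, None, jnp.zeros((B, C)))       # first two slots are passed through unchanged
params = ffn.init(jax.random.PRNGKey(0), carry, x)
carry, y = ffn.apply(params, carry, x)        # y.shape == (T, B, C)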
class Rwkv6Block(nn.Module):
num_heads: int
intermediate_size: Optional[int] = None
layer_norm_epsilon: float = 1e-5
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, carry, inputs):
layer_norm = partial(
nn.LayerNorm, epsilon=self.layer_norm_epsilon)
x = inputs
y = layer_norm(name="ln1")(x)
carry, y = Rwkv6SelfAttention(
num_heads=self.num_heads,
dtype=self.dtype,
param_dtype=self.param_dtype
)(carry, y)
x = x + y
y = layer_norm(name="ln2")(x)
carry, y = Rwkv6FeedForward(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype
)(carry, y)
x = x + y
return carry, x
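Finally, a sketch (not part of the commit) of unrolling the single-step Rwkv6SelfAttention cell over a time-major sequence with jax.lax.scan; the same parameters are shared across steps, which is how a recurrent cell like this is typically driven during training.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

T, B, num_heads, head_size = 6, 2, 4, 32
C = num_heads * head_size
cell = Rwkv6SelfAttention(num_heads=num_heads)

carry0 = (jnp.zeros((B, C)), jnp.zeros((B, C * head_size)))
xs = jnp.ones((T, B, C))
params = cell.init(jax.random.PRNGKey(0), carry0, xs[0])

def step(carry, x_t):
    return cell.apply(params, carry, x_t)     # -> (new_carry, y_t)

carry_T, ys = jax.lax.scan(step, carry0, xs)  # ys: (T, B, C)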