Biluo Shen / ygo-agent · Commits

Commit c2417798, authored May 23, 2024 by sbl1996@126.com

    Add RWKV

Parent: 4ef751bf

Showing 3 changed files with 409 additions and 7 deletions:
    ygoai/rl/jax/agent.py    +33   -7
    ygoai/rl/jax/modules.py  +61   -0
    ygoai/rl/jax/rwkv.py     +315  -0
ygoai/rl/jax/agent.py
@@ -8,7 +8,8 @@ import jax.numpy as jnp
 import flax.linen as nn

 from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
-from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.rwkv import Rwkv6SelfAttention

 default_embed_init = nn.initializers.uniform(scale=0.001)
@@ -153,6 +154,7 @@ class GlobalEncoder(nn.Module):
     channels: int = 128
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
+    version: int = 0

     @nn.compact
     def __call__(self, x):
@@ -196,6 +198,7 @@ class GlobalEncoder(nn.Module):
 class Encoder(nn.Module):
     channels: int = 128
+    out_channels: Optional[int] = None
     num_layers: int = 2
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: Optional[jnp.dtype] = None
@@ -264,8 +267,14 @@ class Encoder(nn.Module):

         # Global
-        x_global = GlobalEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
+        x_global = GlobalEncoder(
+            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
         x_global = x_global.astype(self.dtype)
-        f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
-        f_global = fc_layer(c, dtype=self.dtype)(f_global)
+        if self.version == 2:
+            x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
+            x_global = fc_layer(c, dtype=jnp.float32)(x_global)
+            f_global = x_global + GLUMlp(c, dtype=self.dtype, param_dtype=self.param_dtype)(layer_norm(dtype=self.dtype)(x_global))
+        else:
+            f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
+            f_global = fc_layer(c, dtype=self.dtype)(f_global)
         f_global = layer_norm(dtype=self.dtype)(f_global)
@@ -391,7 +400,8 @@ class Encoder(nn.Module):
             f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
         else:
             f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
-        f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
+        oc = self.out_channels or c
+        f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         f_state = layer_norm(dtype=self.dtype)(f_state)
         return f_actions, f_state, a_mask, valid
@@ -498,12 +508,14 @@ class ModelArgs:
     """whether to use history actions as input for agent"""
     card_mask: bool = False
     """whether to mask the padding card as ignored in the transformer"""
-    rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
+    rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
     """the type of RNN to use, None for no RNN"""
     film: bool = False
     """whether to use FiLM for the actor"""
     noam: bool = False
     """whether to use Noam architecture for the transformer layer"""
+    rwkv_head_size: int = 32
+    """the head size for the RWKV"""
     version: int = 0
     """the version of the environment and the agent"""
@@ -522,13 +534,16 @@ class RNNAgent(nn.Module):
     rnn_type: str = 'lstm'
     film: bool = False
     noam: bool = False
+    rwkv_head_size: int = 32
     version: int = 0

     @nn.compact
     def __call__(self, x, rstate, done=None, switch_or_main=None):
         c = self.num_channels
+        oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
         encoder = Encoder(
-            channels=c, num_layers=self.num_layers, embedding_shape=self.embedding_shape,
+            channels=c, out_channels=oc,
+            num_layers=self.num_layers, embedding_shape=self.embedding_shape,
             dtype=self.dtype,
@@ -548,6 +563,10 @@ class RNNAgent(nn.Module):
         elif self.rnn_type == 'gru':
             rnn_layer = nn.GRUCell(
                 self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
                 kernel_init=nn.initializers.orthogonal(1.0))
+        elif self.rnn_type == 'rwkv':
+            num_heads = self.rnn_channels // self.rwkv_head_size
+            rnn_layer = Rwkv6SelfAttention(
+                num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
         elif self.rnn_type is None:
             rnn_layer = None
@@ -596,5 +615,12 @@ class RNNAgent(nn.Module):
             )
         elif self.rnn_type == 'gru':
             return np.zeros((batch_size, self.rnn_channels))
+        elif self.rnn_type == 'rwkv':
+            head_size = self.rwkv_head_size
+            num_heads = self.rnn_channels // self.rwkv_head_size
+            return (
+                np.zeros((batch_size, num_heads * head_size)),
+                np.zeros((batch_size, num_heads * head_size * head_size)),
+            )
         else:
             return None
\ No newline at end of file
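For reference, a small NumPy sketch of the initial recurrent state returned by the new 'rwkv' branch of init_rnn_state. The concrete sizes are illustrative, and reading the two arrays as a per-channel token-shift state plus a per-head head_size x head_size wkv state is an assumption about the collapsed rwkv.py, not something this diff states.

    import numpy as np

    batch_size, rnn_channels, head_size = 8, 512, 32  # illustrative values only
    num_heads = rnn_channels // head_size             # 16 heads
    state = (
        np.zeros((batch_size, num_heads * head_size)),              # (8, 512)
        np.zeros((batch_size, num_heads * head_size * head_size)),  # (8, 16384)
    )
    print(state[0].shape, state[1].shape)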
ygoai/rl/jax/modules.py
@@ -1,4 +1,6 @@
 from typing import Tuple, Union, Optional
+import functools

+import jax
 import jax.numpy as jnp
 import flax.linen as nn
@@ -51,3 +53,62 @@ class MLP(nn.Module):
             if i < n - 1 or not self.last_lin:
                 x = nn.leaky_relu(x, negative_slope=0.1)
         return x
+
+
+class GLUMlp(nn.Module):
+    intermediate_size: int
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
+    last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
+    bias_init: nn.initializers.Initializer = nn.initializers.zeros
+    use_bias: bool = False
+
+    @nn.compact
+    def __call__(self, inputs):
+        dense = [
+            functools.partial(
+                nn.DenseGeneral,
+                use_bias=self.use_bias,
+                dtype=self.dtype,
+                param_dtype=self.param_dtype,
+                kernel_init=self.kernel_init,
+                bias_init=self.bias_init,
+            )
+            for _ in range(3)
+        ]
+
+        actual_out_dim = inputs.shape[-1]
+        g = dense[0](
+            features=self.intermediate_size,
+            name="gate",
+        )(inputs)
+        g = nn.silu(g)
+        x = g * dense[1](
+            features=self.intermediate_size,
+            name="up",
+        )(inputs)
+        x = dense[2](
+            features=actual_out_dim,
+            name="down",
+        )(x)
+        return x
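A minimal usage sketch of the new GLUMlp (sizes illustrative): the "gate" and "up" projections expand to intermediate_size, their SiLU-gated product is projected back to the input width by "down".

    import jax
    import jax.numpy as jnp

    mlp = GLUMlp(intermediate_size=256)
    x = jnp.ones((4, 128))                       # (batch, features)
    params = mlp.init(jax.random.PRNGKey(0), x)  # initializes the three Dense kernels
    y = mlp.apply(params, x)
    assert y.shape == x.shape                    # "down" restores inputs.shape[-1]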
+
+
+class RMSNorm(nn.Module):
+    epsilon: float = 1e-6
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, x):
+        dtype = jnp.promote_types(self.dtype, jnp.float32)
+        x = jnp.asarray(x, dtype)
+        x = x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.epsilon)
+
+        reduced_feature_shape = (x.shape[-1],)
+        scale = self.param(
+            "scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype)
+        x = x * scale
+        return jnp.asarray(x, self.dtype)
\ No newline at end of file
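For reference, RMSNorm rescales each feature vector by its reciprocal root mean square, y = x * rsqrt(mean(x^2) + epsilon) * scale, with no centering or bias term. A tiny numeric check (values illustrative):

    import jax.numpy as jnp

    x = jnp.array([[3.0, 4.0]])
    rms = jnp.sqrt(jnp.square(x).mean(-1, keepdims=True) + 1e-6)  # ~3.5355
    print(x / rms)  # ~[[0.8485, 1.1314]]; the module then multiplies by the learned "scale" (ones at init)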
ygoai/rl/jax/rwkv.py (new file, 0 → 100644)

    This diff is collapsed (315 additions).