Biluo Shen / ygo-agent / Commits

Commit e1ff8f92, authored Jun 13, 2024 by sbl1996@126.com
Add more rnn options and batch norm
parent 974fe861
Showing 4 changed files with 373 additions and 83 deletions (+373 -83):

  scripts/cleanba.py        +99   -73
  ygoai/rl/jax/agent.py     +54   -8
  ygoai/rl/jax/modules.py   +144  -2
  ygoai/rl/jax/utils.py     +76   -0
scripts/cleanba.py

(This diff is collapsed in the commit view and is not reproduced here.)
ygoai/rl/jax/agent.py
@@ -8,7 +8,7 @@ import jax.numpy as jnp
 import flax.linen as nn

 from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
-from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
+from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
 from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
@@ -487,7 +487,7 @@ class Critic(nn.Module):
     param_dtype: jnp.dtype = jnp.float32

     @nn.compact
-    def __call__(self, f_state):
+    def __call__(self, f_state, train):
         f_state = f_state.astype(self.dtype)
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         x = mlp(self.channels, last_lin=False)(f_state)
@@ -495,6 +495,33 @@ class Critic(nn.Module):
         return x


+class CrossCritic(nn.Module):
+    channels: Sequence[int] = (128, 128, 128)
+    # dropout_rate: Optional[float] = None
+    batch_norm_momentum: float = 0.99
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, f_state, train):
+        x = f_state.astype(self.dtype)
+        linear = partial(nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
+        BN = partial(
+            BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
+            momentum=self.batch_norm_momentum, axis_name="local_devices",
+            use_running_average=not train)
+        x = BN()(x)
+        for c in self.channels:
+            x = linear(c)(x)
+            # if self.use_layer_norm:
+            #     x = nn.LayerNorm()(x)
+            x = nn.relu(x)
+            # x = nn.leaky_relu(x, negative_slope=0.1)
+            x = BN()(x)
+        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
+        return x
+
+
 class GlobalCritic(nn.Module):
     channels: Sequence[int] = (128, 128)
     dtype: Optional[jnp.dtype] = None
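A minimal usage sketch of the new CrossCritic (not part of the commit): since it builds BatchRenorm with axis_name="local_devices", a training-mode forward pass has to run under a pmap that binds that axis. The shapes, batch size, and PRNG seeds below are illustrative assumptions, not values from the repository.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.agent import CrossCritic

critic = CrossCritic(channels=(128, 128, 128), dtype=jnp.float32)
# (local_devices, batch, features)
f_state = jnp.zeros((jax.local_device_count(), 32, 256))

# Initialization does not need cross-device statistics, but running it under pmap
# keeps the variable shapes aligned with the pmapped train step below.
variables = jax.pmap(
    lambda x: critic.init(jax.random.PRNGKey(0), x, train=False),
    axis_name="local_devices",
)(f_state)

# Training step: batch statistics are updated, so 'batch_stats' must be mutable.
values, mutated = jax.pmap(
    lambda v, x: critic.apply(v, x, train=True, mutable=["batch_stats"]),
    axis_name="local_devices",
)(variables, f_state)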
@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs):
     """whether to use FiLM for the actor"""
     oppo_info: bool = False
     """whether to use opponent's information"""
+    rnn_shortcut: bool = False
+    """whether to use shortcut for the RNN"""
+    batch_norm: bool = False
+    """whether to use batch normalization for the critic"""
+    critic_width: int = 128
+    """the width of the critic"""
+    critic_depth: int = 3
+    """the depth of the critic"""
     rwkv_head_size: int = 32
     """the head size for the RWKV"""
@@ -596,6 +631,10 @@ class RNNAgent(nn.Module):
     rwkv_head_size: int = 32
     action_feats: bool = True
     oppo_info: bool = False
+    rnn_shortcut: bool = False
+    batch_norm: bool = False
+    critic_width: int = 128
+    critic_depth: int = 3
     version: int = 0
     switch: bool = True
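A hedged construction sketch (not from the commit) showing the new fields being enabled; it assumes the remaining RNNAgent fields keep the defaults shown in the class definition above.

from ygoai.rl.jax.agent import RNNAgent

agent = RNNAgent(
    rnn_shortcut=True,   # concatenate the pre-RNN features with the RNN output
    batch_norm=True,     # value head becomes CrossCritic (BatchRenorm) instead of Critic
    critic_width=128,
    critic_depth=3,
)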
@@ -606,7 +645,7 @@ class RNNAgent(nn.Module):
     param_dtype: jnp.dtype = jnp.float32

     @nn.compact
-    def __call__(self, x, rstate, done=None, switch_or_main=None):
+    def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
         batch_size = jax.tree.leaves(rstate)[0].shape[0]

         c = self.num_channels
@@ -669,6 +708,10 @@ class RNNAgent(nn.Module):
         rstate, f_state_r = rnn_step_by_main(
             rnn_layer, rstate, f_state, done, switch_or_main)

+        if self.rnn_shortcut:
+            # f_state_r = ReZero(channel_wise=True)(f_state_r)
+            f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
+
         if self.film:
             actor = FiLMActor(
                 channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
@@ -694,13 +737,16 @@ class RNNAgent(nn.Module):
                 lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
             value = critic(rstate1_t, rstate2_t, f_g)
         else:
-            critic = Critic(
-                channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
-            value = critic(f_state_r)
+            CriticCls = CrossCritic if self.batch_norm else Critic
+            cs = [self.critic_width] * self.critic_depth
+            critic = CriticCls(
+                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+            value = critic(f_state_r, train)

         if self.int_head:
+            cs = [self.critic_width] * self.critic_depth
             critic_int = Critic(
-                channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
+                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
             value_int = critic_int(f_state_r)
             value = (value, value_int)
         return rstate, logits, value, valid
ygoai/rl/jax/modules.py
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union, Optional, Any
 import functools

 import jax
 import jax.numpy as jnp
 import flax.linen as nn
+from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes


 def decode_id(x):
@@ -110,3 +111,144 @@ class RMSNorm(nn.Module):
         )
         x = x * scale
         return jnp.asarray(x, self.dtype)
+
+
+class ReZero(nn.Module):
+    channel_wise: bool = False
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, x):
+        shape = (x.shape[-1],) if self.channel_wise else ()
+        scale = self.param("scale", nn.initializers.zeros, shape, self.param_dtype)
+        return x * scale
+
+
+class BatchRenorm(nn.Module):
+    """BatchRenorm Module, implemented based on the Batch Renormalization paper (https://arxiv.org/abs/1702.03275)
+    and adapted from Flax's BatchNorm implementation:
+    https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228
+
+    Attributes:
+      use_running_average: if True, the statistics stored in batch_stats will be
+        used instead of computing the batch statistics on the input.
+      axis: the feature or non-batch axis of the input.
+      momentum: decay rate for the exponential moving average of the batch
+        statistics.
+      epsilon: a small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      param_dtype: the dtype passed to parameter initializers (default: float32).
+      use_bias: if True, bias (beta) is added.
+      use_scale: if True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: initializer for bias, by default, zero.
+      scale_init: initializer for scale, by default, one.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See `jax.pmap` for a description of axis names (default: None).
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+        examples on the first two and last two devices. See `jax.lax.psum` for
+        more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    use_running_average: Optional[bool] = None
+    axis: int = -1
+    momentum: float = 0.999
+    epsilon: float = 0.001
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+    use_bias: bool = True
+    use_scale: bool = True
+    bias_init: nn.initializers.Initializer = nn.initializers.zeros
+    scale_init: nn.initializers.Initializer = nn.initializers.ones
+    axis_name: Optional[str] = None
+    axis_index_groups: Any = None
+    use_fast_variance: bool = True
+
+    @nn.compact
+    def __call__(self, x, use_running_average: Optional[bool] = None):
+        """
+        Args:
+          x: the input to be normalized.
+          use_running_average: if true, the statistics stored in batch_stats will be
+            used instead of computing the batch statistics on the input.
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        use_running_average = nn.merge_param(
+            'use_running_average', self.use_running_average, use_running_average
+        )
+        feature_axes = _canonicalize_axes(x.ndim, self.axis)
+        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
+        feature_shape = [x.shape[ax] for ax in feature_axes]
+
+        ra_mean = self.variable(
+            'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape)
+        ra_var = self.variable(
+            'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape)
+
+        r_max = self.variable('batch_stats', 'r_max', lambda s: s, 3)
+        d_max = self.variable('batch_stats', 'd_max', lambda s: s, 5)
+        steps = self.variable('batch_stats', 'steps', lambda s: s, 0)
+
+        if use_running_average:
+            mean, var = ra_mean.value, ra_var.value
+            custom_mean = mean
+            custom_var = var
+        else:
+            mean, var = _compute_stats(
+                x,
+                reduction_axes,
+                dtype=self.dtype,
+                axis_name=self.axis_name if not self.is_initializing() else None,
+                axis_index_groups=self.axis_index_groups,
+                use_fast_variance=self.use_fast_variance,
+            )
+            custom_mean = mean
+            custom_var = var
+            if not self.is_initializing():
+                # The code below is implemented following the Batch Renormalization paper
+                r = 1
+                d = 0
+                std = jnp.sqrt(var + self.epsilon)
+                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
+                r = jax.lax.stop_gradient(std / ra_std)
+                r = jnp.clip(r, 1 / r_max.value, r_max.value)
+                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
+                d = jnp.clip(d, -d_max.value, d_max.value)
+                tmp_var = var / (r ** 2)
+                tmp_mean = mean - d * jnp.sqrt(custom_var) / r
+
+                # Warm up batch renorm for 100_000 steps to build up proper running statistics
+                warmed_up = jnp.greater_equal(steps.value, 100_000).astype(jnp.float32)
+                custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
+                custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean
+
+                ra_mean.value = (
+                    self.momentum * ra_mean.value + (1 - self.momentum) * mean
+                )
+                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
+                steps.value += 1
+
+        return _normalize(
+            self,
+            x,
+            custom_mean,
+            custom_var,
+            reduction_axes,
+            feature_axes,
+            self.dtype,
+            self.param_dtype,
+            self.epsilon,
+            self.use_bias,
+            self.use_scale,
+            self.bias_init,
+            self.scale_init,
+        )
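A minimal single-device usage sketch of the new BatchRenorm module (not part of the commit). With axis_name left at its default of None it can be exercised without pmap; the array shapes and momentum value are illustrative assumptions.

import jax
import jax.numpy as jnp
from ygoai.rl.jax.modules import BatchRenorm

bn = BatchRenorm(momentum=0.99)
x = jax.random.normal(jax.random.PRNGKey(0), (64, 16))

variables = bn.init(jax.random.PRNGKey(1), x, use_running_average=False)

# Training: the running mean/var, clip bounds, and step counter live in 'batch_stats'
# and must be declared mutable so they can be updated in place.
y_train, mutated = bn.apply(
    variables, x, use_running_average=False, mutable=["batch_stats"])
variables = {**variables, **mutated}

# Evaluation: normalize with the accumulated running statistics.
y_eval = bn.apply(variables, x, use_running_average=True)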
ygoai/rl/jax/utils.py
+from typing import Any, Callable
+
 import jax
 import jax.numpy as jnp
+from flax import core, struct
+from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
+import optax
 import numpy as np

 from ygoai.rl.env import RecordEpisodeStatistics
@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments(
     new_count = tot_count

     return new_mean, new_var, new_count
+
+
+class TrainState(struct.PyTreeNode):
+    step: int
+    apply_fn: Callable = struct.field(pytree_node=False)
+    params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
+    tx: optax.GradientTransformation = struct.field(pytree_node=False)
+    opt_state: optax.OptState = struct.field(pytree_node=True)
+    batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
+
+    def apply_gradients(self, *, grads, **kwargs):
+        """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
+
+        Note that internally this function calls ``.tx.update()`` followed by a call
+        to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
+
+        Args:
+          grads: Gradients that have the same pytree structure as ``.params``.
+          **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
+
+        Returns:
+          An updated instance of ``self`` with ``step`` incremented by one, ``params``
+          and ``opt_state`` updated by applying ``grads``, and additional attributes
+          replaced as specified by ``kwargs``.
+        """
+        if OVERWRITE_WITH_GRADIENT in grads:
+            grads_with_opt = grads['params']
+            params_with_opt = self.params['params']
+        else:
+            grads_with_opt = grads
+            params_with_opt = self.params
+
+        updates, new_opt_state = self.tx.update(
+            grads_with_opt, self.opt_state, params_with_opt)
+        new_params_with_opt = optax.apply_updates(params_with_opt, updates)
+
+        # As implied by the OWG name, the gradients are used directly to update the
+        # parameters.
+        if OVERWRITE_WITH_GRADIENT in grads:
+            new_params = {
+                'params': new_params_with_opt,
+                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
+            }
+        else:
+            new_params = new_params_with_opt
+        return self.replace(
+            step=self.step + 1,
+            params=new_params,
+            opt_state=new_opt_state,
+            **kwargs,
+        )
+
+    @classmethod
+    def create(cls, *, apply_fn, params, tx, **kwargs):
+        """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
+        # We exclude OWG params when present because they do not need opt states.
+        params_with_opt = (
+            params['params'] if OVERWRITE_WITH_GRADIENT in params else params
+        )
+        opt_state = tx.init(params_with_opt)
+        return cls(
+            step=0,
+            apply_fn=apply_fn,
+            params=params,
+            tx=tx,
+            opt_state=opt_state,
+            **kwargs,
+        )
\ No newline at end of file
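A hedged end-to-end sketch (not from the repository) of how the batch_stats field added to TrainState can be threaded through a training step. TinyCritic, the input shapes, and the learning rate are hypothetical and exist only to illustrate the flow.

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from ygoai.rl.jax.utils import TrainState
from ygoai.rl.jax.modules import BatchRenorm


class TinyCritic(nn.Module):  # hypothetical model, for illustration only
    @nn.compact
    def __call__(self, x, train: bool):
        x = BatchRenorm(use_running_average=not train)(x)
        return nn.Dense(1)(x)


model = TinyCritic()
x = jnp.ones((8, 4))
variables = model.init(jax.random.PRNGKey(0), x, train=False)

state = TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adam(3e-4),
    batch_stats=variables["batch_stats"],
)

def loss_fn(params):
    # Return the updated batch statistics as auxiliary output.
    out, mutated = state.apply_fn(
        {"params": params, "batch_stats": state.batch_stats},
        x, train=True, mutable=["batch_stats"])
    return jnp.mean(out ** 2), mutated["batch_stats"]

(loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
state = state.apply_gradients(grads=grads, batch_stats=new_batch_stats)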