ygo-agent (Biluo Shen) · Commits · 6752ba72

Commit 6752ba72, authored May 16, 2024 by sbl1996@126.com
Parent: 1d35fed3

    Add card_mask and film to agent

Showing 4 changed files with 153 additions and 72 deletions:

    scripts/battle.py            +9   -26
    scripts/cleanba.py           +22  -28
    ygoai/rl/jax/agent2.py       +81  -9
    ygoai/rl/jax/transformer.py  +41  -9
scripts/battle.py (+9, -26)

```diff
@@ -3,7 +3,7 @@ import time
 import os
 import random
 from typing import Optional
-from dataclasses import dataclass
+from dataclasses import dataclass, field, asdict
 from tqdm import tqdm
 from functools import partial
@@ -18,7 +18,7 @@ import flax
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
-from ygoai.rl.jax.agent2 import RNNAgent
+from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
 
 
 @dataclass
@@ -44,10 +44,6 @@ class Args:
     """the number of history actions to use"""
     num_embeddings: Optional[int] = None
     """the number of embeddings of the agent"""
-    use_history1: bool = True
-    """whether to use history actions as input for agent1"""
-    use_history2: bool = True
-    """whether to use history actions as input for agent2"""
     verbose: bool = False
     """whether to print debug information"""
@@ -59,16 +55,11 @@ class Args:
     num_envs: int = 64
     """the number of parallel game environments"""
-    num_layers: int = 2
-    """the number of layers for the agent"""
-    num_channels: int = 128
-    """the number of channels for the agent"""
-    rnn_channels: Optional[int] = 512
-    """the number of rnn channels for the agent"""
-    rnn_type1: Optional[str] = "lstm"
-    """the type of RNN to use for agent1, None for no RNN"""
-    rnn_type2: Optional[str] = "lstm"
-    """the type of RNN to use for agent2, None for no RNN"""
+    m1: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent1"""
+    m2: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent2"""
 
     checkpoint1: str = "checkpoints/agent.pt"
     """the checkpoint to load for the first agent, must be a `flax_model` file"""
     checkpoint2: str = "checkpoints/agent.pt"
@@ -83,23 +74,15 @@ class Args:
 def create_agent1(args):
     return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        rnn_channels=args.rnn_channels,
+        **asdict(args.m1),
         embedding_shape=args.num_embeddings,
-        use_history=args.use_history1,
-        rnn_type=args.rnn_type1,
     )
 
 
 def create_agent2(args):
     return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        rnn_channels=args.rnn_channels,
+        **asdict(args.m2),
         embedding_shape=args.num_embeddings,
-        use_history=args.use_history2,
-        rnn_type=args.rnn_type2,
     )
```
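The refactor above replaces the per-agent hyperparameter fields (num_layers, num_channels, rnn_channels, use_history1/2, rnn_type1/2) with one ModelArgs bundle per agent, splatted into the constructor with asdict. A minimal sketch of the pattern, with an abbreviated stand-in field set (the real ModelArgs is defined in ygoai/rl/jax/agent2.py below):

```python
# Minimal sketch of the ModelArgs/asdict pattern; the field set is abbreviated.
from dataclasses import dataclass, field, asdict

@dataclass
class ModelArgs:
    num_layers: int = 2
    num_channels: int = 128
    rnn_type: str = "lstm"

@dataclass
class Args:
    # each agent gets an independent, fully typed model config
    m1: ModelArgs = field(default_factory=lambda: ModelArgs())
    m2: ModelArgs = field(default_factory=lambda: ModelArgs(num_layers=4))

args = Args()
# asdict flattens the config into constructor kwargs, so adding a new model
# option only touches ModelArgs, not every create_agent* call site:
print(asdict(args.m1))  # {'num_layers': 2, 'num_channels': 128, 'rnn_type': 'lstm'}
```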
scripts/cleanba.py (+22, -28)

```diff
@@ -6,7 +6,7 @@ import threading
 import time
 from datetime import datetime, timedelta, timezone
 from collections import deque
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, asdict
 from types import SimpleNamespace
 from typing import List, NamedTuple, Optional
 from functools import partial
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
-from ygoai.rl.jax.agent2 import RNNAgent
+from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
@@ -80,10 +80,6 @@ class Args:
     """the number of history actions to use"""
     greedy_reward: bool = False
     """whether to use greedy reward (faster kill higher reward)"""
-    use_history: bool = True
-    """whether to use history actions as input for agent"""
-    eval_use_history: bool = True
-    """whether to use history actions as input for eval agent"""
     total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
@@ -146,16 +142,10 @@ class Args:
     max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""
-    num_layers: int = 2
-    """the number of layers for the agent"""
-    num_channels: int = 128
-    """the number of channels for the agent"""
-    rnn_channels: int = 512
-    """the number of channels for the RNN in the agent"""
-    rnn_type: Optional[str] = "lstm"
-    """the type of RNN to use, None for no RNN"""
-    eval_rnn_type: Optional[str] = "lstm"
-    """the type of RNN to use for evaluation, None for no RNN"""
+    m1: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent"""
+    m2: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the eval agent"""
 
     actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
     """the device ids that actor workers will use"""
@@ -228,18 +218,22 @@ class Transition(NamedTuple):
 def create_agent(args, eval=False):
-    return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        embedding_shape=args.num_embeddings,
-        dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
-        param_dtype=jnp.float32,
-        rnn_channels=args.rnn_channels,
-        switch=args.switch,
-        freeze_id=args.freeze_id,
-        use_history=args.use_history if not eval else args.eval_use_history,
-        rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
-    )
+    if eval:
+        return RNNAgent(
+            embedding_shape=args.num_embeddings,
+            dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
+            param_dtype=jnp.float32,
+            **asdict(args.m2),
+        )
+    else:
+        return RNNAgent(
+            embedding_shape=args.num_embeddings,
+            dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
+            param_dtype=jnp.float32,
+            switch=args.switch,
+            freeze_id=args.freeze_id,
+            **asdict(args.m1),
+        )
 
 
 def init_rnn_state(num_envs, rnn_channels):
```
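Functionally, the new create_agent is the old one split in two: the evaluation agent is now built purely from args.m2, while the training agent gets args.m1 plus the training-only switch and freeze_id options. A condensed, logically equivalent restatement (not the literal source):

```python
# Condensed restatement of the eval/train branch above.
def create_agent(args, eval=False):
    m = args.m2 if eval else args.m1
    extra = {} if eval else dict(switch=args.switch, freeze_id=args.freeze_id)
    return RNNAgent(
        embedding_shape=args.num_embeddings,
        dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
        param_dtype=jnp.float32,
        **extra,
        **asdict(m),
    )
```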
ygoai/rl/jax/agent2.py (+81, -9)

```diff
-from typing import Tuple, Union, Optional, Sequence
+from dataclasses import dataclass
+from typing import Tuple, Union, Optional, Sequence, Literal
 from functools import partial
 
 import numpy as np
@@ -6,7 +7,7 @@ import jax
 import jax.numpy as jnp
 import flax.linen as nn
 
-from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding
+from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
 from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
```
```diff
@@ -15,6 +16,13 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
 default_fc_init2 = nn.initializers.uniform(scale=0.001)
 
 
+def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
+    if noam:
+        return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
+    else:
+        return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)
+
+
 class ActionEncoder(nn.Module):
     channels: int = 128
     dtype: Optional[jnp.dtype] = None
@@ -69,6 +77,8 @@ class CardEncoder(nn.Module):
         x_id = layer_norm()(x_id)
 
         x_loc = x1[:, :, 0]
+        c_mask = x_loc == 0
+        c_mask = c_mask.at[:, 0].set(False)
         f_loc = layer_norm()(embed(9, c)(x_loc))
 
         x_seq = x1[:, :, 1]
```
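The new c_mask marks padding card slots (location feature equal to 0) so attention can skip them, and the first slot is force-unmasked so a row is never fully masked. A standalone illustration with invented values:

```python
# Standalone illustration of the c_mask construction; x_loc values are invented.
import jax.numpy as jnp

x_loc = jnp.array([[0, 3, 0, 0]])        # 0 = padding card slot
c_mask = x_loc == 0                      # True = ignore this card in attention
c_mask = c_mask.at[:, 0].set(False)      # first slot always stays attendable
print(c_mask)                            # [[False False  True  True]]
```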
```diff
@@ -97,7 +107,7 @@ class CardEncoder(nn.Module):
         f_cards = jnp.concatenate([x_id, x_f], axis=-1)
         f_cards = f_cards + f_loc + f_seq
-        return f_cards
+        return f_cards, c_mask
 
 
 class GlobalEncoder(nn.Module):
```
```diff
@@ -153,6 +163,8 @@ class Encoder(nn.Module):
     param_dtype: jnp.dtype = jnp.float32
     freeze_id: bool = False
     use_history: bool = True
+    card_mask: bool = False
+    noam: bool = False
 
     @nn.compact
     def __call__(self, x):
```
```diff
@@ -188,7 +200,7 @@ class Encoder(nn.Module):
             x_id = jax.lax.stop_gradient(x_id)
 
         # Cards
-        f_cards = CardEncoder(
+        f_cards, c_mask = CardEncoder(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
         g_card_embed = self.param(
             'g_card_embed',
```
```diff
@@ -196,10 +208,16 @@ class Encoder(nn.Module):
             (1, c), self.param_dtype)
         f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
         f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
+        if self.card_mask:
+            c_mask = jnp.concatenate([
+                jnp.zeros((batch_size, 1), dtype=c_mask.dtype),
+                c_mask,
+            ], axis=1)
+        else:
+            c_mask = None
 
         num_heads = max(2, c // 128)
         for _ in range(self.num_layers):
-            f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
+            f_cards = get_encoder_layer_cls(
+                self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
+                f_cards, src_key_padding_mask=c_mask)
         f_cards = layer_norm(dtype=self.dtype)(f_cards)
         f_g_card = f_cards[:, 0]
```
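Because a learned global card embedding is prepended to the card sequence, the mask gets a matching leading False so that token is always attended; with card_mask disabled the mask collapses to None and the layers behave exactly as before. The shape bookkeeping, as a small runnable sketch with invented sizes:

```python
# Shape sketch for the mask extension; batch size and card count are invented.
import jax.numpy as jnp

batch_size, n_cards = 2, 4
c_mask = jnp.zeros((batch_size, n_cards), dtype=jnp.bool_)   # per-card mask
c_mask = jnp.concatenate(
    [jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
print(c_mask.shape)   # (2, 5): one slot per token, global embedding included
```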
```diff
@@ -294,6 +312,32 @@ class Actor(nn.Module):
         return logits
 
 
+class FiLMActor(nn.Module):
+    channels: int = 128
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+    noam: bool = False
+
+    @nn.compact
+    def __call__(self, f_state, f_actions, mask):
+        f_state = f_state.astype(self.dtype)
+        f_actions = f_actions.astype(self.dtype)
+        c = self.channels
+        t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
+        a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
+
+        num_heads = max(2, c // 128)
+        f_actions = get_encoder_layer_cls(
+            self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
+            f_actions, mask, a_s, a_b, o_s, o_b)
+
+        logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
+                          kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
+        big_neg = jnp.finfo(logits.dtype).min
+        logits = jnp.where(mask, big_neg, logits)
+        return logits
+
+
 class Critic(nn.Module):
     channels: Sequence[int] = (128, 128, 128)
     dtype: Optional[jnp.dtype] = None
```
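FiLMActor conditions the action transformer on the (recurrent) state: one Dense layer maps f_state to four per-channel vectors that scale and shift the attention and MLP branches inside the encoder layer, via the attn_scale/attn_bias/output_scale/output_bias hooks added in transformer.py below. A toy, self-contained module showing just the core FiLM mechanic (shapes and names are illustrative, not the repo's):

```python
# Toy FiLM conditioning demo; module and shapes are illustrative only.
import jax
import jax.numpy as jnp
import flax.linen as nn

class FiLM(nn.Module):
    channels: int = 8

    @nn.compact
    def __call__(self, cond, feats):
        # one projection yields a (scale, bias) pair per channel
        t = nn.Dense(self.channels * 2)(cond)          # [batch, 2 * channels]
        scale, bias = jnp.split(t[:, None, :], 2, axis=-1)
        return feats * scale + bias                    # broadcast over sequence

cond = jnp.ones((2, 16))        # plays the role of f_state
feats = jnp.ones((2, 5, 8))     # plays the role of f_actions
film = FiLM()
params = film.init(jax.random.PRNGKey(0), cond, feats)
print(film.apply(params, cond, feats).shape)  # (2, 5, 8)
```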
```diff
@@ -340,10 +384,29 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
     return rstate, f_state
 
 
+@dataclass
+class ModelArgs:
+    num_layers: int = 2
+    """the number of layers for the agent"""
+    num_channels: int = 128
+    """the number of channels for the agent"""
+    rnn_channels: int = 512
+    """the number of channels for the RNN in the agent"""
+    use_history: bool = True
+    """whether to use history actions as input for agent"""
+    card_mask: bool = False
+    """whether to mask the padding card as ignored in the transformer"""
+    rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
+    """the type of RNN to use, None for no RNN"""
+    film: bool = False
+    """whether to use FiLM for the actor"""
+    noam: bool = False
+    """whether to use Noam architecture for the transformer layer"""
+
+
 class RNNAgent(nn.Module):
-    channels: int = 128
     num_layers: int = 2
+    num_channels: int = 128
     rnn_channels: int = 512
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: jnp.dtype = jnp.float32
```
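ModelArgs gathers every architecture knob behind defaults that reproduce the previous behavior (film, noam, and card_mask all default to False), so existing configs keep their old semantics and variants opt in explicitly:

```python
# Usage sketch: the defaults match the old architecture; variants opt in.
baseline = ModelArgs()
variant = ModelArgs(film=True, noam=True, card_mask=True, rnn_type="gru")
```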
```diff
@@ -351,11 +414,14 @@ class RNNAgent(nn.Module):
     switch: bool = True
     freeze_id: bool = False
     use_history: bool = True
+    card_mask: bool = False
     rnn_type: str = 'lstm'
+    film: bool = False
+    noam: bool = False
 
     @nn.compact
     def __call__(self, x, rstate, done=None, switch_or_main=None):
-        c = self.channels
+        c = self.num_channels
         encoder = Encoder(
             channels=c,
             num_layers=self.num_layers,
```
```diff
@@ -364,6 +430,8 @@ class RNNAgent(nn.Module):
             param_dtype=self.param_dtype,
             freeze_id=self.freeze_id,
             use_history=self.use_history,
+            card_mask=self.card_mask,
+            noam=self.noam,
         )
         f_actions, f_state, mask, valid = encoder(x)
```
```diff
@@ -401,8 +469,12 @@ class RNNAgent(nn.Module):
             rstate, f_state_r = rnn_step_by_main(
                 rnn_layer, rstate, f_state, done, switch_or_main)
 
-        actor = Actor(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+        if self.film:
+            actor = FiLMActor(
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
+        else:
+            actor = Actor(
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
         critic = Critic(
             channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
```
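Putting the agent2.py changes together, an agent with all of the new options enabled would be constructed roughly as below; the input pytree x and the RNN state come from the environment wrappers and are omitted here, and the values are illustrative rather than recommended settings:

```python
# Hypothetical construction sketch; values are illustrative, not recommended.
agent = RNNAgent(
    num_channels=128,   # renamed from `channels`
    num_layers=2,
    rnn_channels=512,
    card_mask=True,     # padding cards ignored by card self-attention
    film=True,          # FiLMActor in place of the plain Actor head
    noam=True,          # LlamaEncoderLayer (RMSNorm + GLU MLP) blocks
    rnn_type="lstm",
)
```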
ygoai/rl/jax/transformer.py (+41, -9)

```diff
@@ -596,7 +596,6 @@ class GLUMlpBlock(nn.Module):
             param_dtype=self.param_dtype,
             kernel_init=self.kernel_init,
             bias_init=self.bias_init,
-            shard=self.shard,
         ) for _ in range(3)
     ]
```
```diff
@@ -631,7 +630,10 @@ class EncoderLayer(nn.Module):
     deterministic: bool = True
 
     @nn.compact
-    def __call__(self, inputs, src_key_padding_mask=None):
+    def __call__(
+        self, inputs, src_key_padding_mask=None,
+        attn_scale=None, attn_bias=None,
+        output_scale=None, output_bias=None):
         inputs = jnp.asarray(inputs, self.dtype)
         x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
                          dtype=self.dtype, name="ln_1")(inputs)
```
```diff
@@ -648,6 +650,11 @@ class EncoderLayer(nn.Module):
         x = nn.Dropout(rate=self.resid_pdrop)(
             x, deterministic=self.deterministic)
+        if attn_scale is not None:
+            x = x * attn_scale
+        if attn_bias is not None:
+            x = x + attn_bias
         x = x + inputs
 
         y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
```
```diff
@@ -662,7 +669,13 @@ class EncoderLayer(nn.Module):
                          name="mlp")(y)
         y = nn.Dropout(rate=self.resid_pdrop)(
             y, deterministic=self.deterministic)
+        if output_scale is not None:
+            y = y * output_scale
+        if output_bias is not None:
+            y = y + output_bias
         y = x + y
         return y
```
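The four new optional arguments are FiLM hook points: each residual branch is scaled and shifted after dropout and before being added back to the stream. The ordering matters, and is easy to check with a scalar stand-in:

```python
# Pure-Python check of the modulate-then-residual ordering used above.
def modulated_residual(inputs, sublayer, scale=None, bias=None):
    x = sublayer(inputs)
    if scale is not None:
        x = x * scale
    if bias is not None:
        x = x + bias
    return x + inputs            # residual is added after modulation

# sublayer doubles its input: (1 * 2) * 0.5 + 1, plus the residual 1 -> 3.0
print(modulated_residual(1.0, lambda v: 2 * v, scale=0.5, bias=1.0))
```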
```diff
@@ -733,8 +746,9 @@ class DecoderLayer(nn.Module):
 class LlamaEncoderLayer(nn.Module):
     n_heads: int
-    intermediate_size: int
+    intermediate_size: Optional[int] = None
     n_positions: int = 512
+    rope: bool = True
     dtype: Any = None
     param_dtype: Any = jnp.float32
     attn_pdrop: float = 0.0
```
```diff
@@ -745,11 +759,17 @@ class LlamaEncoderLayer(nn.Module):
     deterministic: bool = True
 
     @nn.compact
-    def __call__(self, inputs, src_key_padding_mask=None):
+    def __call__(
+        self, inputs, src_key_padding_mask=None,
+        attn_scale=None, attn_bias=None,
+        output_scale=None, output_bias=None):
+        features = inputs.shape[-1]
+        intermediate_size = self.intermediate_size or 2 * features
         x = RMSNorm(epsilon=self.rms_norm_eps,
                     dtype=self.dtype, name="ln_1")(inputs)
         x = MultiheadAttention(
-            features=x.shape[-1],
+            features=features,
             num_heads=self.n_heads,
             max_len=self.n_positions,
             dtype=self.dtype,
```
```diff
@@ -757,19 +777,24 @@ class LlamaEncoderLayer(nn.Module):
             kernel_init=self.kernel_init,
             qkv_bias=False,
             out_bias=False,
-            rope=True,
+            rope=self.rope,
             dropout_rate=self.attn_pdrop,
             deterministic=self.deterministic,
             name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
         x = nn.Dropout(rate=self.resid_pdrop)(
             x, deterministic=self.deterministic)
+        if attn_scale is not None:
+            x = x * attn_scale
+        if attn_bias is not None:
+            x = x + attn_bias
         x = x + inputs
 
         y = RMSNorm(epsilon=self.rms_norm_eps,
                     dtype=self.dtype, name="ln_2")(x)
         y = GLUMlpBlock(
-            intermediate_size=self.intermediate_size,
+            intermediate_size=intermediate_size,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             kernel_init=self.kernel_init,
```
```diff
@@ -777,6 +802,12 @@ class LlamaEncoderLayer(nn.Module):
             name="mlp")(y)
         y = nn.Dropout(rate=self.resid_pdrop)(
             y, deterministic=self.deterministic)
+        if output_scale is not None:
+            y = y * output_scale
+        if output_bias is not None:
+            y = y + output_bias
         y = x + y
         return y
```
```diff
@@ -785,6 +816,7 @@ class LlamaDecoderLayer(nn.Module):
     n_heads: int
     intermediate_size: int
     n_positions: int = 512
+    rope: bool = True
     dtype: Any = None
     param_dtype: Any = jnp.float32
     attn_pdrop: float = 0.0
```
```diff
@@ -808,7 +840,7 @@ class LlamaDecoderLayer(nn.Module):
             kernel_init=self.kernel_init,
             qkv_bias=False,
             out_bias=False,
-            rope=True,
+            rope=self.rope,
             dropout_rate=self.attn_pdrop,
             deterministic=self.deterministic,
             name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
```
...
@@ -827,7 +859,7 @@ class LlamaDecoderLayer(nn.Module):
kernel_init
=
self
.
kernel_init
,
kernel_init
=
self
.
kernel_init
,
qkv_bias
=
False
,
qkv_bias
=
False
,
out_bias
=
False
,
out_bias
=
False
,
rope
=
Tru
e
,
rope
=
self
.
rop
e
,
dropout_rate
=
self
.
attn_pdrop
,
dropout_rate
=
self
.
attn_pdrop
,
deterministic
=
self
.
deterministic
,
deterministic
=
self
.
deterministic
,
name
=
"cross_attn"
)(
y
,
memory
,
memory
,
key_padding_mask
=
memory_key_padding_mask
)
name
=
"cross_attn"
)(
y
,
memory
,
memory
,
key_padding_mask
=
memory_key_padding_mask
)
...
...
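With rope now a constructor field, agent2.py can build a position-free Llama-style block: get_encoder_layer_cls passes rope=False, which reads as treating the cards as an unordered set rather than a sequence (my reading of the change; the commit itself gives no rationale). Usage, schematically:

```python
# Usage sketch: the same layer class with and without rotary embeddings.
seq_layer = LlamaEncoderLayer(n_heads=2)              # rope=True by default
set_layer = LlamaEncoderLayer(n_heads=2, rope=False)  # as built for the card set
```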