Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
662b300f
Commit
662b300f
authored
Jul 10, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update doc and defaults for release
parent
03416f14
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
183 additions
and
244 deletions
+183
-244
README.md
README.md
+67
-62
scripts/battle.py
scripts/battle.py
+3
-3
scripts/cleanba.py
scripts/cleanba.py
+4
-4
scripts/cleanba_g.py
scripts/cleanba_g.py
+2
-2
scripts/cleanba_rnd.py
scripts/cleanba_rnd.py
+2
-2
scripts/eval.py
scripts/eval.py
+1
-1
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+98
-164
ygoai/rl/jax/nnx/agent.py
ygoai/rl/jax/nnx/agent.py
+6
-6
No files found.
README.md
View file @
662b300f
This diff is collapsed.
Click to expand it.
scripts/battle.py
View file @
662b300f
...
@@ -131,7 +131,7 @@ if __name__ == "__main__":
...
@@ -131,7 +131,7 @@ if __name__ == "__main__":
seed
=
args
.
seed
+
100000
seed
=
args
.
seed
+
100000
random
.
seed
(
seed
)
random
.
seed
(
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
random
.
seed
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
...
@@ -165,6 +165,7 @@ if __name__ == "__main__":
...
@@ -165,6 +165,7 @@ if __name__ == "__main__":
oppo_info
=
args
.
oppo_info
,
oppo_info
=
args
.
oppo_info
,
**
env_option
,
**
env_option
,
)
)
envs1
.
num_envs
=
num_envs
envs1
=
EnvPreprocess
(
envs1
,
skip_mask
=
not
args
.
oppo_info
)
envs1
=
EnvPreprocess
(
envs1
,
skip_mask
=
not
args
.
oppo_info
)
if
cross_env
:
if
cross_env
:
...
@@ -175,11 +176,11 @@ if __name__ == "__main__":
...
@@ -175,11 +176,11 @@ if __name__ == "__main__":
deck2
=
deck2
,
deck2
=
deck2
,
**
env_option
,
**
env_option
,
)
)
envs2
.
num_envs
=
num_envs
key
=
jax
.
random
.
PRNGKey
(
seed
)
key
=
jax
.
random
.
PRNGKey
(
seed
)
obs_space1
=
envs1
.
observation_space
obs_space1
=
envs1
.
observation_space
envs1
.
num_envs
=
num_envs
envs1
=
RecordEpisodeStatistics
(
envs1
)
envs1
=
RecordEpisodeStatistics
(
envs1
)
sample_obs1
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space1
.
sample
())
sample_obs1
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space1
.
sample
())
agent1
=
create_agent1
(
args
)
agent1
=
create_agent1
(
args
)
...
@@ -190,7 +191,6 @@ if __name__ == "__main__":
...
@@ -190,7 +191,6 @@ if __name__ == "__main__":
if
cross_env
:
if
cross_env
:
obs_space2
=
envs2
.
observation_space
obs_space2
=
envs2
.
observation_space
envs2
.
num_envs
=
num_envs
envs2
=
RecordEpisodeStatistics
(
envs2
)
envs2
=
RecordEpisodeStatistics
(
envs2
)
sample_obs2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space2
.
sample
())
sample_obs2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space2
.
sample
())
else
:
else
:
...
...
scripts/cleanba.py
View file @
662b300f
...
@@ -106,7 +106,7 @@ class Args:
...
@@ -106,7 +106,7 @@ class Args:
"""the discount factor gamma"""
"""the discount factor gamma"""
num_minibatches
:
int
=
64
num_minibatches
:
int
=
64
"""the number of mini-batches"""
"""the number of mini-batches"""
update_epochs
:
int
=
2
update_epochs
:
int
=
1
"""the K epochs to update the policy"""
"""the K epochs to update the policy"""
switch
:
bool
=
False
switch
:
bool
=
False
"""Toggle the use of switch mechanism"""
"""Toggle the use of switch mechanism"""
...
@@ -119,7 +119,7 @@ class Args:
...
@@ -119,7 +119,7 @@ class Args:
"""Toggle the use of UPGO for advantages"""
"""Toggle the use of UPGO for advantages"""
sep_value
:
bool
=
True
sep_value
:
bool
=
True
"""Whether separate value function computation for each player"""
"""Whether separate value function computation for each player"""
value
:
Literal
[
"vtrace"
,
"gae"
]
=
"
vtrac
e"
value
:
Literal
[
"vtrace"
,
"gae"
]
=
"
ga
e"
"""the method to learn the value function"""
"""the method to learn the value function"""
gae_lambda
:
float
=
0.95
gae_lambda
:
float
=
0.95
"""the lambda for the general advantage estimation"""
"""the lambda for the general advantage estimation"""
...
@@ -715,14 +715,14 @@ def main():
...
@@ -715,14 +715,14 @@ def main():
# seeding
# seeding
random
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
seed_offset
=
args
.
local_rank
seed_offset
=
args
.
local_rank
seed
+=
seed_offset
seed
+=
seed_offset
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
random
.
seed
(
seed
)
random
.
seed
(
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
args
.
real_seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
...
...
scripts/cleanba_g.py
View file @
662b300f
...
@@ -716,14 +716,14 @@ def main():
...
@@ -716,14 +716,14 @@ def main():
# seeding
# seeding
random
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
seed_offset
=
args
.
local_rank
seed_offset
=
args
.
local_rank
seed
+=
seed_offset
seed
+=
seed_offset
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
random
.
seed
(
seed
)
random
.
seed
(
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
args
.
real_seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
...
...
scripts/cleanba_rnd.py
View file @
662b300f
...
@@ -743,14 +743,14 @@ def main():
...
@@ -743,14 +743,14 @@ def main():
# seeding
# seeding
random
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
seed_offset
=
args
.
local_rank
seed_offset
=
args
.
local_rank
seed
+=
seed_offset
seed
+=
seed_offset
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
random
.
seed
(
seed
)
random
.
seed
(
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
args
.
real_seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
...
...
scripts/eval.py
View file @
662b300f
...
@@ -96,7 +96,7 @@ if __name__ == "__main__":
...
@@ -96,7 +96,7 @@ if __name__ == "__main__":
seed
=
args
.
seed
+
100000
seed
=
args
.
seed
+
100000
random
.
seed
(
seed
)
random
.
seed
(
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
random
.
seed
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
...
...
ygoai/rl/jax/agent.py
View file @
662b300f
...
@@ -83,11 +83,10 @@ class CardEncoder(nn.Module):
...
@@ -83,11 +83,10 @@ class CardEncoder(nn.Module):
dtype
:
Optional
[
jnp
.
dtype
]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
oppo_info
:
bool
=
False
oppo_info
:
bool
=
False
version
:
int
=
0
version
:
int
=
2
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x_id
,
x
,
mask
):
def
__call__
(
self
,
x_id
,
x
,
mask
):
assert
self
.
version
>
0
c
=
self
.
channels
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
,
dtype
=
self
.
dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
,
dtype
=
self
.
dtype
)
...
@@ -105,13 +104,6 @@ class CardEncoder(nn.Module):
...
@@ -105,13 +104,6 @@ class CardEncoder(nn.Module):
x_loc
=
x1
[:,
:,
0
]
x_loc
=
x1
[:,
:,
0
]
x_seq
=
x1
[:,
:,
1
]
x_seq
=
x1
[:,
:,
1
]
if
self
.
version
==
0
:
x_id
=
mlp
(
(
c
,
c
//
4
),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
layer_norm
()(
x_id
)
f_loc
=
layer_norm
()(
embed
(
9
,
c
)(
x_loc
))
f_seq
=
layer_norm
()(
embed
(
76
,
c
)(
x_seq
))
c_mask
=
x_loc
==
0
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
...
@@ -130,16 +122,6 @@ class CardEncoder(nn.Module):
...
@@ -130,16 +122,6 @@ class CardEncoder(nn.Module):
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x2
[:,
:,
4
:])
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x2
[:,
:,
4
:])
if
self
.
version
==
0
:
x_f
=
jnp
.
concatenate
([
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_f
=
layer_norm
()(
x_f
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
f_cards_g
=
None
else
:
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
...
@@ -175,7 +157,7 @@ class GlobalEncoder(nn.Module):
...
@@ -175,7 +157,7 @@ class GlobalEncoder(nn.Module):
channels
:
int
=
128
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
version
:
int
=
0
version
:
int
=
2
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
):
def
__call__
(
self
,
x
):
...
@@ -230,7 +212,7 @@ class Encoder(nn.Module):
...
@@ -230,7 +212,7 @@ class Encoder(nn.Module):
noam
:
bool
=
False
noam
:
bool
=
False
action_feats
:
bool
=
True
action_feats
:
bool
=
True
oppo_info
:
bool
=
False
oppo_info
:
bool
=
False
version
:
int
=
0
version
:
int
=
2
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
):
def
__call__
(
self
,
x
):
...
@@ -252,7 +234,7 @@ class Encoder(nn.Module):
...
@@ -252,7 +234,7 @@ class Encoder(nn.Module):
card_encoder
=
CardEncoder
(
card_encoder
=
CardEncoder
(
channels
=
c
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
channels
=
c
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
,
oppo_info
=
self
.
oppo_info
)
version
=
self
.
version
,
oppo_info
=
self
.
oppo_info
)
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoder
V1
ActionEncoderCls
=
ActionEncoderV1
action_encoder
=
ActionEncoderCls
(
action_encoder
=
ActionEncoderCls
(
channels
=
c
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
...
@@ -313,33 +295,6 @@ class Encoder(nn.Module):
...
@@ -313,33 +295,6 @@ class Encoder(nn.Module):
# History actions
# History actions
x_h_actions
=
x_h_actions
.
astype
(
jnp
.
int32
)
x_h_actions
=
x_h_actions
.
astype
(
jnp
.
int32
)
if
self
.
version
==
0
:
h_mask
=
x_h_actions
[:,
:,
2
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
:
2
])
x_h_id
=
id_embed
(
x_h_id
)
if
self
.
freeze_id
:
x_h_id
=
jax
.
lax
.
stop_gradient
(
x_h_id
)
x_h_id
=
MLP
(
(
c
,
c
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
x_h_id
)
x_h_a_feats1
=
action_encoder
(
x_h_actions
[:,
:,
2
:
13
])
x_h_a_player
=
embed
(
2
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
14
])
x_h_a_feats
=
jnp
.
concatenate
([
*
x_h_a_feats1
,
x_h_a_player
,
x_h_a_turn
],
axis
=-
1
)
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
fc_layer
(
c
)(
x_h_a_feats
))
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
for
_
in
range
(
self
.
num_layers
):
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_h_actions
,
src_key_padding_mask
=
h_mask
)
f_g_h_actions
=
layer_norm
()(
f_h_actions
[:,
0
])
else
:
h_mask
=
x_h_actions
[:,
:,
3
]
==
0
# msg == 0
h_mask
=
x_h_actions
[:,
:,
3
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
...
@@ -379,28 +334,6 @@ class Encoder(nn.Module):
...
@@ -379,28 +334,6 @@ class Encoder(nn.Module):
f_na_card
=
jnp
.
tile
(
na_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_na_card
=
jnp
.
tile
(
na_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
[:,
1
:]],
axis
=
1
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
[:,
1
:]],
axis
=
1
)
if
self
.
version
==
0
:
spec_index
=
decode_id
(
x_actions
[
...
,
:
2
])
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
fc_layer
(
c
)(
f_a_cards
)
x_a_feats
=
jnp
.
concatenate
(
action_encoder
(
x_actions
[
...
,
2
:]),
axis
=-
1
)
x_a_feats
=
fc_layer
(
c
)(
x_a_feats
)
f_actions
=
jnp
.
concatenate
([
f_a_cards
,
x_a_feats
],
axis
=-
1
)
f_actions
=
fc_layer
(
c
)(
nn
.
leaky_relu
(
f_actions
,
negative_slope
=
0.1
))
f_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_actions
)
a_mask
=
x_actions
[:,
:,
2
]
==
0
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
a_mask_
=
(
1
-
a_mask
.
astype
(
f_actions
.
dtype
))
f_g_actions
=
(
f_actions
*
a_mask_
[:,
:,
None
])
.
sum
(
axis
=
1
)
f_g_actions
=
f_g_actions
/
a_mask_
.
sum
(
axis
=
1
,
keepdims
=
True
)
if
not
self
.
use_history
:
f_g_h_actions
=
jnp
.
zeros_like
(
f_g_h_actions
)
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
else
:
spec_index
=
x_actions
[
...
,
0
]
spec_index
=
x_actions
[
...
,
0
]
B
=
jnp
.
arange
(
batch_size
)
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
...
@@ -436,6 +369,7 @@ class Encoder(nn.Module):
...
@@ -436,6 +369,7 @@ class Encoder(nn.Module):
g_feats
.
append
(
f_g_actions
)
g_feats
.
append
(
f_g_actions
)
f_state
=
jnp
.
concatenate
(
g_feats
,
axis
=-
1
)
f_state
=
jnp
.
concatenate
(
g_feats
,
axis
=-
1
)
oc
=
self
.
out_channels
or
c
oc
=
self
.
out_channels
or
c
if
self
.
version
==
2
:
if
self
.
version
==
2
:
f_state
=
GLUMlp
(
f_state
=
GLUMlp
(
...
@@ -573,7 +507,7 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False)
...
@@ -573,7 +507,7 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False)
return
rstate
,
f_state
return
rstate
,
f_state
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
Tru
e
,
return_state
=
False
):
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
Fals
e
,
return_state
=
False
):
if
switch
:
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
rstate
,
init_rstate2
=
carry
...
@@ -601,11 +535,11 @@ class EncoderArgs:
...
@@ -601,11 +535,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent"""
"""whether to use history actions as input for agent"""
card_mask
:
bool
=
False
card_mask
:
bool
=
False
"""whether to mask the padding card as ignored in the transformer"""
"""whether to mask the padding card as ignored in the transformer"""
noam
:
bool
=
Fals
e
noam
:
bool
=
Tru
e
"""whether to use Noam architecture for the transformer layer"""
"""whether to use Noam architecture for the transformer layer"""
action_feats
:
bool
=
True
action_feats
:
bool
=
True
"""whether to use action features for the global state"""
"""whether to use action features for the global state"""
version
:
int
=
0
version
:
int
=
2
"""the version of the environment and the agent"""
"""the version of the environment and the agent"""
...
@@ -615,7 +549,7 @@ class ModelArgs(EncoderArgs):
...
@@ -615,7 +549,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent"""
"""the number of channels for the RNN in the agent"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
Fals
e
film
:
bool
=
Tru
e
"""whether to use FiLM for the actor"""
"""whether to use FiLM for the actor"""
oppo_info
:
bool
=
False
oppo_info
:
bool
=
False
"""whether to use opponent's information"""
"""whether to use opponent's information"""
...
@@ -638,8 +572,8 @@ class RNNAgent(nn.Module):
...
@@ -638,8 +572,8 @@ class RNNAgent(nn.Module):
use_history
:
bool
=
True
use_history
:
bool
=
True
card_mask
:
bool
=
False
card_mask
:
bool
=
False
rnn_type
:
str
=
'lstm'
rnn_type
:
str
=
'lstm'
film
:
bool
=
Fals
e
film
:
bool
=
Tru
e
noam
:
bool
=
Fals
e
noam
:
bool
=
Tru
e
rwkv_head_size
:
int
=
32
rwkv_head_size
:
int
=
32
action_feats
:
bool
=
True
action_feats
:
bool
=
True
oppo_info
:
bool
=
False
oppo_info
:
bool
=
False
...
@@ -647,10 +581,10 @@ class RNNAgent(nn.Module):
...
@@ -647,10 +581,10 @@ class RNNAgent(nn.Module):
batch_norm
:
bool
=
False
batch_norm
:
bool
=
False
critic_width
:
int
=
128
critic_width
:
int
=
128
critic_depth
:
int
=
3
critic_depth
:
int
=
3
version
:
int
=
0
version
:
int
=
2
q_head
:
bool
=
False
q_head
:
bool
=
False
switch
:
bool
=
Tru
e
switch
:
bool
=
Fals
e
freeze_id
:
bool
=
False
freeze_id
:
bool
=
False
int_head
:
bool
=
False
int_head
:
bool
=
False
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
...
...
ygoai/rl/jax/nnx/agent.py
View file @
662b300f
...
@@ -646,11 +646,11 @@ class EncoderArgs:
...
@@ -646,11 +646,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent"""
"""whether to use history actions as input for agent"""
card_mask
:
bool
=
False
card_mask
:
bool
=
False
"""whether to mask the padding card as ignored in the transformer"""
"""whether to mask the padding card as ignored in the transformer"""
noam
:
bool
=
Fals
e
noam
:
bool
=
Tru
e
"""whether to use Noam architecture for the transformer layer"""
"""whether to use Noam architecture for the transformer layer"""
action_feats
:
bool
=
True
action_feats
:
bool
=
True
"""whether to use action features for the global state"""
"""whether to use action features for the global state"""
version
:
int
=
0
version
:
int
=
2
"""the version of the environment and the agent"""
"""the version of the environment and the agent"""
...
@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
...
@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent"""
"""the number of channels for the RNN in the agent"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
Fals
e
film
:
bool
=
Tru
e
"""whether to use FiLM for the actor"""
"""whether to use FiLM for the actor"""
rnn_shortcut
:
bool
=
False
rnn_shortcut
:
bool
=
False
"""whether to use shortcut for the RNN"""
"""whether to use shortcut for the RNN"""
...
@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
...
@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
use_history
:
bool
=
True
,
use_history
:
bool
=
True
,
card_mask
:
bool
=
False
,
card_mask
:
bool
=
False
,
rnn_type
:
str
=
'lstm'
,
rnn_type
:
str
=
'lstm'
,
film
:
bool
=
Fals
e
,
film
:
bool
=
Tru
e
,
noam
:
bool
=
Fals
e
,
noam
:
bool
=
Tru
e
,
rwkv_head_size
:
int
=
32
,
rwkv_head_size
:
int
=
32
,
action_feats
:
bool
=
True
,
action_feats
:
bool
=
True
,
rnn_shortcut
:
bool
=
False
,
rnn_shortcut
:
bool
=
False
,
batch_norm
:
bool
=
False
,
batch_norm
:
bool
=
False
,
critic_width
:
int
=
128
,
critic_width
:
int
=
128
,
critic_depth
:
int
=
3
,
critic_depth
:
int
=
3
,
version
:
int
=
0
,
version
:
int
=
2
,
q_head
:
bool
=
False
,
q_head
:
bool
=
False
,
switch
:
bool
=
True
,
switch
:
bool
=
True
,
freeze_id
:
bool
=
False
,
freeze_id
:
bool
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment