Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
662b300f
Commit
662b300f
authored
Jul 10, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update doc and defaults for release
parent
03416f14
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
183 additions
and
244 deletions
+183
-244
README.md
README.md
+67
-62
scripts/battle.py
scripts/battle.py
+3
-3
scripts/cleanba.py
scripts/cleanba.py
+4
-4
scripts/cleanba_g.py
scripts/cleanba_g.py
+2
-2
scripts/cleanba_rnd.py
scripts/cleanba_rnd.py
+2
-2
scripts/eval.py
scripts/eval.py
+1
-1
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+98
-164
ygoai/rl/jax/nnx/agent.py
ygoai/rl/jax/nnx/agent.py
+6
-6
No files found.
README.md
View file @
662b300f
This diff is collapsed.
Click to expand it.
scripts/battle.py
View file @
662b300f
...
...
@@ -131,7 +131,7 @@ if __name__ == "__main__":
seed
=
args
.
seed
+
100000
random
.
seed
(
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
...
...
@@ -165,6 +165,7 @@ if __name__ == "__main__":
oppo_info
=
args
.
oppo_info
,
**
env_option
,
)
envs1
.
num_envs
=
num_envs
envs1
=
EnvPreprocess
(
envs1
,
skip_mask
=
not
args
.
oppo_info
)
if
cross_env
:
...
...
@@ -175,11 +176,11 @@ if __name__ == "__main__":
deck2
=
deck2
,
**
env_option
,
)
envs2
.
num_envs
=
num_envs
key
=
jax
.
random
.
PRNGKey
(
seed
)
obs_space1
=
envs1
.
observation_space
envs1
.
num_envs
=
num_envs
envs1
=
RecordEpisodeStatistics
(
envs1
)
sample_obs1
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space1
.
sample
())
agent1
=
create_agent1
(
args
)
...
...
@@ -190,7 +191,6 @@ if __name__ == "__main__":
if
cross_env
:
obs_space2
=
envs2
.
observation_space
envs2
.
num_envs
=
num_envs
envs2
=
RecordEpisodeStatistics
(
envs2
)
sample_obs2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space2
.
sample
())
else
:
...
...
scripts/cleanba.py
View file @
662b300f
...
...
@@ -106,7 +106,7 @@ class Args:
"""the discount factor gamma"""
num_minibatches
:
int
=
64
"""the number of mini-batches"""
update_epochs
:
int
=
2
update_epochs
:
int
=
1
"""the K epochs to update the policy"""
switch
:
bool
=
False
"""Toggle the use of switch mechanism"""
...
...
@@ -119,7 +119,7 @@ class Args:
"""Toggle the use of UPGO for advantages"""
sep_value
:
bool
=
True
"""Whether separate value function computation for each player"""
value
:
Literal
[
"vtrace"
,
"gae"
]
=
"
vtrac
e"
value
:
Literal
[
"vtrace"
,
"gae"
]
=
"
ga
e"
"""the method to learn the value function"""
gae_lambda
:
float
=
0.95
"""the lambda for the general advantage estimation"""
...
...
@@ -715,14 +715,14 @@ def main():
# seeding
random
.
seed
(
args
.
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
seed_offset
=
args
.
local_rank
seed
+=
seed_offset
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
random
.
seed
(
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
args
.
real_seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
...
...
scripts/cleanba_g.py
View file @
662b300f
...
...
@@ -716,14 +716,14 @@ def main():
# seeding
random
.
seed
(
args
.
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
seed_offset
=
args
.
local_rank
seed
+=
seed_offset
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
random
.
seed
(
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
args
.
real_seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
...
...
scripts/cleanba_rnd.py
View file @
662b300f
...
...
@@ -743,14 +743,14 @@ def main():
# seeding
random
.
seed
(
args
.
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
seed_offset
=
args
.
local_rank
seed
+=
seed_offset
init_key
=
jax
.
random
.
PRNGKey
(
seed
-
seed_offset
)
random
.
seed
(
seed
)
args
.
real_seed
=
random
.
randint
(
0
,
1e8
)
args
.
real_seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
key
=
jax
.
random
.
PRNGKey
(
args
.
real_seed
)
key
,
*
learner_keys
=
jax
.
random
.
split
(
key
,
len
(
learner_devices
)
+
1
)
...
...
scripts/eval.py
View file @
662b300f
...
...
@@ -96,7 +96,7 @@ if __name__ == "__main__":
seed
=
args
.
seed
+
100000
random
.
seed
(
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
seed
=
random
.
randint
(
0
,
int
(
1e8
)
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
...
...
ygoai/rl/jax/agent.py
View file @
662b300f
This diff is collapsed.
Click to expand it.
ygoai/rl/jax/nnx/agent.py
View file @
662b300f
...
...
@@ -646,11 +646,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent"""
card_mask
:
bool
=
False
"""whether to mask the padding card as ignored in the transformer"""
noam
:
bool
=
Fals
e
noam
:
bool
=
Tru
e
"""whether to use Noam architecture for the transformer layer"""
action_feats
:
bool
=
True
"""whether to use action features for the global state"""
version
:
int
=
0
version
:
int
=
2
"""the version of the environment and the agent"""
...
...
@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'rwkv'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
Fals
e
film
:
bool
=
Tru
e
"""whether to use FiLM for the actor"""
rnn_shortcut
:
bool
=
False
"""whether to use shortcut for the RNN"""
...
...
@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
use_history
:
bool
=
True
,
card_mask
:
bool
=
False
,
rnn_type
:
str
=
'lstm'
,
film
:
bool
=
Fals
e
,
noam
:
bool
=
Fals
e
,
film
:
bool
=
Tru
e
,
noam
:
bool
=
Tru
e
,
rwkv_head_size
:
int
=
32
,
action_feats
:
bool
=
True
,
rnn_shortcut
:
bool
=
False
,
batch_norm
:
bool
=
False
,
critic_width
:
int
=
128
,
critic_depth
:
int
=
3
,
version
:
int
=
0
,
version
:
int
=
2
,
q_head
:
bool
=
False
,
switch
:
bool
=
True
,
freeze_id
:
bool
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment