Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
330ee6af
Commit
330ee6af
authored
May 30, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add truncated LSTM
parent
cd59a6e9
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
396 additions
and
1073 deletions
+396
-1073
scripts/battle.py
scripts/battle.py
+158
-78
scripts/cleanba.py
scripts/cleanba.py
+198
-49
scripts/eval.py
scripts/eval.py
+40
-30
scripts/ppo.py
scripts/ppo.py
+0
-916
No files found.
scripts/battle.py
View file @
330ee6af
This diff is collapsed.
Click to expand it.
scripts/cleanba.py
View file @
330ee6af
This diff is collapsed.
Click to expand it.
scripts/eval.py
View file @
330ee6af
...
...
@@ -3,7 +3,7 @@ import time
import
os
import
random
from
typing
import
Optional
,
Literal
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
,
asdict
import
ygoenv
import
numpy
as
np
...
...
@@ -12,6 +12,7 @@ import tyro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
@
dataclass
...
...
@@ -57,14 +58,8 @@ class Args:
strategy
:
Literal
[
"random"
,
"greedy"
]
=
"greedy"
"""the strategy to use if agent is not used"""
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
rnn_channels
:
Optional
[
int
]
=
512
"""the number of rnn channels for the agent"""
rnn_type
:
Optional
[
str
]
=
"lstm"
"""the type of RNN to use for agent, None for no RNN"""
m
:
ModelArgs
=
field
(
default_factory
=
lambda
:
ModelArgs
())
"""the model arguments for the agent1"""
checkpoint
:
Optional
[
str
]
=
None
"""the checkpoint to load, must be a `flax_model` file"""
...
...
@@ -78,11 +73,8 @@ class Args:
def
create_agent
(
args
):
return
RNNAgent
(
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
rnn_channels
=
args
.
rnn_channels
,
**
asdict
(
args
.
m
),
embedding_shape
=
args
.
num_embeddings
,
rnn_type
=
args
.
rnn_type
,
)
...
...
@@ -97,12 +89,14 @@ if __name__ == "__main__":
args
.
env_threads
=
min
(
args
.
env_threads
or
args
.
num_envs
,
args
.
num_envs
)
deck
=
init_ygopro
(
args
.
env_id
,
args
.
lang
,
args
.
deck
,
args
.
code_list_fil
e
)
deck
,
deck_names
=
init_ygopro
(
args
.
env_id
,
args
.
lang
,
args
.
deck
,
args
.
code_list_file
,
return_deck_names
=
Tru
e
)
args
.
deck1
=
args
.
deck1
or
deck
args
.
deck2
=
args
.
deck2
or
deck
seed
=
args
.
seed
seed
=
args
.
seed
+
100000
random
.
seed
(
seed
)
seed
=
random
.
randint
(
0
,
1e8
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
...
...
@@ -135,22 +129,21 @@ if __name__ == "__main__":
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.rl.jax.agent
import
RNNAgent
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
=
jax
.
random
.
PRNGKey
(
seed
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_
key
,
sample_obs
,
rstate
)
params
=
jax
.
jit
(
agent
.
init
)(
key
,
sample_obs
,
rstate
)
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params
=
jax
.
device_put
(
params
)
rstate
=
agent
.
init_rnn_state
(
num_envs
)
@
jax
.
jit
def
get_probs_and_value
(
params
,
rstate
,
obs
,
done
):
...
...
@@ -180,6 +173,10 @@ if __name__ == "__main__":
start
=
time
.
time
()
start_step
=
step
deck_names
=
sorted
(
deck_names
)
deck_times
=
{
name
:
0
for
name
in
deck_names
}
deck_time_count
=
{
name
:
0
for
name
in
deck_names
}
model_time
=
env_time
=
0
while
True
:
if
start_step
==
0
and
len
(
episode_lengths
)
>
int
(
args
.
num_episodes
*
0.1
):
...
...
@@ -211,17 +208,30 @@ if __name__ == "__main__":
step
+=
1
for
idx
,
d
in
enumerate
(
dones
):
if
d
:
win_reason
=
infos
[
'win_reason'
][
idx
]
episode_length
=
infos
[
'l'
][
idx
]
episode_reward
=
infos
[
'r'
][
idx
]
win
=
int
(
episode_reward
>
0
)
episode_lengths
.
append
(
episode_length
)
episode_rewards
.
append
(
episode_reward
)
win_rates
.
append
(
win
)
win_reasons
.
append
(
1
if
win_reason
==
1
else
0
)
sys
.
stderr
.
write
(
f
"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}
\n
"
)
if
not
d
:
continue
for
i
in
range
(
2
):
deck_time
=
infos
[
'step_time'
][
idx
][
i
]
deck_name
=
deck_names
[
infos
[
'deck'
][
idx
][
i
]]
time_count
=
deck_time_count
[
deck_name
]
avg_time
=
deck_times
[
deck_name
]
avg_time
=
avg_time
*
(
time_count
/
(
time_count
+
1
))
+
deck_time
/
(
time_count
+
1
)
deck_times
[
deck_name
]
=
avg_time
deck_time_count
[
deck_name
]
+=
1
if
deck_time_count
[
deck_name
]
%
100
==
0
:
print
(
f
"Deck {deck_name}: {avg_time:.4f}"
)
win_reason
=
infos
[
'win_reason'
][
idx
]
episode_length
=
infos
[
'l'
][
idx
]
episode_reward
=
infos
[
'r'
][
idx
]
win
=
int
(
episode_reward
>
0
)
episode_lengths
.
append
(
episode_length
)
episode_rewards
.
append
(
episode_reward
)
win_rates
.
append
(
win
)
win_reasons
.
append
(
1
if
win_reason
==
1
else
0
)
sys
.
stderr
.
write
(
f
"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}
\n
"
)
if
len
(
episode_lengths
)
>=
args
.
num_episodes
:
break
...
...
scripts/ppo.py
deleted
100644 → 0
View file @
cd59a6e9
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment