Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
89c209f2
Commit
89c209f2
authored
Apr 24, 2024
by
Biluo Shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add ModelCheckpoint
parent
45506246
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
9 deletions
+68
-9
scripts/ppo.py
scripts/ppo.py
+15
-8
ygoai/rl/ckpt.py
ygoai/rl/ckpt.py
+52
-0
ygoai/rl/jax/switch.py
ygoai/rl/jax/switch.py
+1
-1
No files found.
scripts/ppo.py
View file @
89c209f2
...
@@ -23,6 +23,7 @@ from rich.pretty import pprint
...
@@ -23,6 +23,7 @@ from rich.pretty import pprint
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.ckpt
import
ModelCheckpoint
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax.eval
import
evaluate
,
battle
...
@@ -45,6 +46,10 @@ class Args:
...
@@ -45,6 +46,10 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)"""
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint
:
Optional
[
str
]
=
None
checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to load"""
"""the path to the model checkpoint to load"""
checkpoint_dir
:
str
=
"checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket
:
Optional
[
str
]
=
None
"""the GCS bucket to save the model checkpoints"""
# Algorithm specific arguments
# Algorithm specific arguments
env_id
:
str
=
"YGOPro-v0"
env_id
:
str
=
"YGOPro-v0"
...
@@ -525,6 +530,14 @@ if __name__ == "__main__":
...
@@ -525,6 +530,14 @@ if __name__ == "__main__":
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
)
def
save_fn
(
obj
,
path
):
with
open
(
path
,
"wb"
)
as
f
:
f
.
write
(
flax
.
serialization
.
to_bytes
(
obj
))
ckpt_maneger
=
ModelCheckpoint
(
args
.
checkpoint_dir
,
save_fn
,
n_saved
=
3
,
gcs_bucket
=
args
.
gcs_bucket
)
# seeding
# seeding
random
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
...
@@ -854,15 +867,9 @@ if __name__ == "__main__":
...
@@ -854,15 +867,9 @@ if __name__ == "__main__":
writer
.
add_scalar
(
"losses/loss"
,
loss
,
global_step
)
writer
.
add_scalar
(
"losses/loss"
,
loss
,
global_step
)
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
:
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
:
ckpt_dir
=
f
"checkpoints"
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
M_steps
=
args
.
batch_size
*
learner_policy_version
//
(
2
**
20
)
M_steps
=
args
.
batch_size
*
learner_policy_version
//
(
2
**
20
)
model_path
=
os
.
path
.
join
(
ckpt_dir
,
f
"{timestamp}_{M_steps}M.flax_model"
)
ckpt_name
=
f
"{timestamp}_{M_steps}M.flax_model"
with
open
(
model_path
,
"wb"
)
as
f
:
ckpt_maneger
.
save
(
unreplicated_params
,
ckpt_name
)
f
.
write
(
flax
.
serialization
.
to_bytes
(
unreplicated_params
)
)
print
(
f
"model saved to {model_path}"
)
if
learner_policy_version
>=
args
.
num_updates
:
if
learner_policy_version
>=
args
.
num_updates
:
break
break
...
...
ygoai/rl/ckpt.py
0 → 100644
View file @
89c209f2
import
os
from
pathlib
import
Path
class ModelCheckpoint(object):
    """Periodically save objects to disk, keeping only the newest ones.

    Args:
        dirname (str):
            Directory path where objects will be saved. It is created on the
            first call to `save` if it does not exist.
        save_fn (callable):
            Function that will be called to save the object. It should have
            the signature `save_fn(obj, path)`.
        n_saved (int, optional):
            Number of objects that should be kept on disk. Older files will
            be removed.
        gcs_bucket (str, optional):
            If provided, will sync the saved model to the specified GCS
            bucket. A leading "gs://" is accepted and stripped. `None` or an
            empty value disables syncing.
    """

    def __init__(self, dirname, save_fn, n_saved=1, gcs_bucket=None):
        self._dirname = Path(dirname).expanduser()
        self._n_saved = n_saved
        self._save_fn = save_fn
        # `gcs_bucket` defaults to None, so only touch it when it is set;
        # calling `.startswith` unconditionally raised AttributeError.
        if gcs_bucket and gcs_bucket.startswith("gs://"):
            gcs_bucket = gcs_bucket[5:]
        # Normalize "" (and a bare "gs://") to None so `save` skips the sync.
        self._gcs_bucket = gcs_bucket or None
        # Paths of the checkpoints written by this instance, oldest first.
        self._saved = []

    def _check_dir(self):
        """Create the checkpoint directory if needed and verify it exists.

        Raises:
            ValueError: If the directory still does not exist after `mkdir`.
        """
        self._dirname.mkdir(parents=True, exist_ok=True)
        # Ensure that dirname exists
        if not self._dirname.exists():
            raise ValueError(
                "Directory path '{}' is not found".format(self._dirname))

    def save(self, obj, name, sync_gcs=True):
        """Save `obj` as `dirname/name` and drop checkpoints beyond `n_saved`.

        Args:
            obj: Object handed to `save_fn`.
            name (str): File name of the checkpoint inside `dirname`.
            sync_gcs (bool, optional): When True and a GCS bucket was
                configured, copy the file to the bucket as "latest<suffix>"
                using a background `gsutil cp`.
        """
        import shlex

        self._check_dir()
        path = self._dirname / name
        self._save_fn(obj, str(path))
        # Re-saving an existing name overwrites the same file: track it once,
        # as the newest entry, so rotation never deletes the fresh checkpoint
        # or tries to remove the same file twice.
        if path in self._saved:
            self._saved.remove(path)
        self._saved.append(path)
        print(f"Saved model to {path}")

        if self._gcs_bucket is not None and sync_gcs:
            fname = "latest" + path.suffix
            gcs_url = Path(self._gcs_bucket) / fname
            gcs_url = f"gs://{gcs_url}"
            # Quote both operands so spaces or shell metacharacters in the
            # path cannot break (or inject into) the command. The trailing
            # `&` keeps the upload in the background.
            command = "gsutil cp {} {} >> gcs_sync.log 2>&1 &".format(
                shlex.quote(str(path)), shlex.quote(gcs_url))
            os.system(command)
            print("Sync to GCS: ", gcs_url)

        # Remove the oldest checkpoints until at most `n_saved` remain.
        while len(self._saved) > max(self._n_saved, 0):
            stale = self._saved.pop(0)
            try:
                os.remove(stale)
            except FileNotFoundError:
                # Already deleted externally; nothing left to clean up.
                pass
ygoai/rl/jax/switch.py
View file @
89c209f2
...
@@ -32,8 +32,8 @@ def truncated_gae_2p0s(
...
@@ -32,8 +32,8 @@ def truncated_gae_2p0s(
_
,
(
advantages
,
returns
)
=
jax
.
lax
.
scan
(
_
,
(
advantages
,
returns
)
=
jax
.
lax
.
scan
(
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
)
)
targets
=
values
+
advantages
if
upgo
:
if
upgo
:
advantages
+=
returns
-
values
advantages
+=
returns
-
values
targets
=
values
+
advantages
targets
=
jax
.
lax
.
stop_gradient
(
targets
)
targets
=
jax
.
lax
.
stop_gradient
(
targets
)
return
targets
,
advantages
return
targets
,
advantages
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment