Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
6af7c4b4
Commit
6af7c4b4
authored
Apr 28, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Tensorboard log only on local_rank 0
parent
45885f81
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
15 deletions
+17
-15
scripts/ppo.py
scripts/ppo.py
+16
-9
ygoai/rl/ckpt.py
ygoai/rl/ckpt.py
+1
-6
No files found.
scripts/ppo.py
View file @
6af7c4b4
import
os
import
shutil
import
queue
import
random
import
threading
...
...
@@ -527,13 +528,18 @@ if __name__ == "__main__":
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
tb_log_dir
=
f
"{args.tb_dir}/{run_name}"
writer
=
SummaryWriter
(
tb_log_dir
)
writer
.
add_text
(
"hyperparameters"
,
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
dummy_writer
=
SimpleNamespace
()
dummy_writer
.
add_scalar
=
lambda
x
,
y
,
z
:
None
tb_log_dir
=
f
"{args.tb_dir}/{run_name}"
if
args
.
local_rank
==
0
:
writer
=
SummaryWriter
(
tb_log_dir
)
writer
.
add_text
(
"hyperparameters"
,
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
else
:
writer
=
dummy_writer
def
save_fn
(
obj
,
path
):
with
open
(
path
,
"wb"
)
as
f
:
...
...
@@ -782,8 +788,6 @@ if __name__ == "__main__":
params_queues
=
[]
rollout_queues
=
[]
eval_queue
=
queue
.
Queue
()
dummy_writer
=
SimpleNamespace
()
dummy_writer
.
add_scalar
=
lambda
x
,
y
,
z
:
None
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
...
...
@@ -875,8 +879,11 @@ if __name__ == "__main__":
ckpt_name
=
f
"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger
.
save
(
unreplicated_params
,
ckpt_name
)
if
args
.
gcs_bucket
is
not
None
:
lastest_path
=
ckpt_maneger
.
get_latest
()
copy_path
=
lastest_path
.
with_name
(
"latest"
+
lastest_path
.
suffix
)
shutil
.
copyfile
(
lastest_path
,
copy_path
)
zip_file_path
=
"latest.zip"
zip_files
(
zip_file_path
,
[
ckpt_maneger
.
get_latest
(
),
tb_log_dir
])
zip_files
(
zip_file_path
,
[
str
(
copy_path
),
tb_log_dir
])
sync_to_gcs
(
args
.
gcs_bucket
,
zip_file_path
)
if
learner_policy_version
>=
args
.
num_updates
:
...
...
ygoai/rl/ckpt.py
View file @
6af7c4b4
import
os
import
shutil
from
pathlib
import
Path
import
zipfile
...
...
@@ -37,17 +36,13 @@ class ModelCheckpoint(object):
self
.
_saved
.
append
(
path
)
print
(
f
"Saved model to {path}"
)
# Copy the lastest checkpoint as latest
lastest_path
=
path
.
with_name
(
"latest"
+
path
.
suffix
)
shutil
.
copyfile
(
path
,
lastest_path
)
if
len
(
self
.
_saved
)
>
self
.
_n_saved
:
path
=
self
.
_saved
.
pop
(
0
)
os
.
remove
(
path
)
def
get_latest
(
self
):
path
=
self
.
_saved
[
-
1
]
return
str
(
path
.
with_name
(
"latest"
+
path
.
suffix
))
return
path
def
sync_to_gcs
(
bucket
,
source
,
dest
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment