Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
eba0e134
Commit
eba0e134
authored
Feb 21, 2024
by
biluo.shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change ckpt_dir
parent
25e8c58b
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
385 additions
and
3 deletions
+385
-3
scripts/dmc_dist.py
scripts/dmc_dist.py
+381
-0
scripts/ppo.py
scripts/ppo.py
+4
-3
No files found.
scripts/dmc_dist.py
0 → 100644
View file @
eba0e134
This diff is collapsed.
Click to expand it.
scripts/ppo.py
View file @
eba0e134
...
...
@@ -146,8 +146,6 @@ def run(local_rank, world_size):
if
args
.
world_size
>
1
:
setup
(
args
.
backend
,
local_rank
,
args
.
world_size
,
args
.
port
)
os
.
makedirs
(
args
.
ckpt_dir
,
exist_ok
=
True
)
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer
=
None
...
...
@@ -159,6 +157,9 @@ def run(local_rank, world_size):
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
ckpt_dir
=
os
.
path
.
join
(
args
.
ckpt_dir
,
run_name
)
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
...
...
@@ -394,7 +395,7 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if
local_rank
==
0
:
if
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
args
.
ckpt_dir
,
f
"ppo_
{iteration}.pth"
))
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"
{iteration}.pth"
))
writer
.
add_scalar
(
"charts/learning_rate"
,
optimizer
.
param_groups
[
0
][
"lr"
],
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
.
item
(),
global_step
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment