Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
7edeea56
Commit
7edeea56
authored
Feb 24, 2024
by
biluo.shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix for nccl timeout
parent
1e253c14
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
17 deletions
+19
-17
scripts/ppo.py
scripts/ppo.py
+17
-16
ygoai/rl/dist.py
ygoai/rl/dist.py
+2
-1
No files found.
scripts/ppo.py
View file @
7edeea56
...
...
@@ -112,11 +112,11 @@ class Args:
"""tensorboard log directory"""
ckpt_dir
:
str
=
"./checkpoints"
"""checkpoint directory"""
save_interval
:
int
=
10
00
save_interval
:
int
=
5
00
"""the number of iterations to save the model"""
log_p
:
float
=
0.1
log_p
:
float
=
1.0
"""the probability of logging"""
port
:
int
=
1235
5
port
:
int
=
1235
6
"""the port to use for distributed training"""
# to be filled in runtime
...
...
@@ -217,8 +217,6 @@ def run(local_rank, world_size):
if
args
.
embedding_file
:
agent
.
load_embeddings
(
embeddings
)
# if args.compile:
# agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
optimizer
=
optim
.
Adam
(
agent
.
parameters
(),
lr
=
args
.
learning_rate
,
eps
=
1e-5
)
scaler
=
GradScaler
(
enabled
=
args
.
fp16_train
,
init_scale
=
2
**
8
)
...
...
@@ -341,18 +339,21 @@ def run(local_rank, world_size):
continue
for
idx
,
d
in
enumerate
(
next_done_
):
if
d
and
random
.
random
()
<
args
.
log_p
:
if
d
:
episode_length
=
info
[
'l'
][
idx
]
episode_reward
=
info
[
'r'
][
idx
]
writer
.
add_scalar
(
"charts/episodic_return"
,
info
[
"r"
][
idx
],
global_step
)
writer
.
add_scalar
(
"charts/episodic_length"
,
info
[
"l"
][
idx
],
global_step
)
avg_ep_returns
.
append
(
episode_reward
)
winner
=
0
if
episode_reward
>
0
else
1
avg_ep_returns
.
append
(
episode_reward
)
avg_win_rates
.
append
(
1
-
winner
)
if
random
.
random
()
<
args
.
log_p
:
n
=
100
if
random
.
random
()
<
10
/
n
or
iteration
<=
2
:
writer
.
add_scalar
(
"charts/episodic_return"
,
info
[
"r"
][
idx
],
global_step
)
writer
.
add_scalar
(
"charts/episodic_length"
,
info
[
"l"
][
idx
],
global_step
)
print
(
f
"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}"
)
if
len
(
avg_win_rates
)
>
100
:
if
len
(
avg_win_rates
)
>
n
:
writer
.
add_scalar
(
"charts/avg_win_rate"
,
np
.
mean
(
avg_win_rates
),
global_step
)
writer
.
add_scalar
(
"charts/avg_ep_return"
,
np
.
mean
(
avg_ep_returns
),
global_step
)
avg_win_rates
=
[]
...
...
ygoai/rl/dist.py
View file @
7edeea56
...
...
@@ -32,6 +32,7 @@ def setup(backend, rank, world_size, port):
dist
.
all_reduce
(
x
,
op
=
dist
.
ReduceOp
.
SUM
)
x
.
mean
()
.
item
()
dist
.
barrier
()
# print(f"Rank {rank} initialized")
def
mp_start
(
run
):
...
...
@@ -39,7 +40,7 @@ def mp_start(run):
if
world_size
==
1
:
run
(
local_rank
=
0
,
world_size
=
world_size
)
else
:
mp
.
set_start_method
(
'spawn'
)
#
mp.set_start_method('spawn')
children
=
[]
for
i
in
range
(
world_size
):
subproc
=
mp
.
Process
(
target
=
run
,
args
=
(
i
,
world_size
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment