ygo-agent · commit ad9980e5
Authored Mar 07, 2024 by biluo.shen
Parent: e5e5402a

Add battle

Showing 8 changed files with 394 additions and 715 deletions:
repo/packages/e/edopro-core/edopro-core    +1    -0
scripts/battle.py                          +31   -27
scripts/eval.py                            +22   -0
scripts/ppo.py                             +1    -1
scripts/ppo_sp.py                          +33   -25
scripts/ppo_sp2.py                         +0    -658
scripts/ppo_sp5.py                         +3    -3
ygoai/rl/agent.py                          +303  -1
edopro-core @ 8c623744
Subproject commit 8c6237444e294b730bce1eccc6fab2721b7cbea9
scripts/battle.py
@@ -65,7 +65,7 @@ class Args:
     checkpoint2: Optional[str] = "checkpoints/agent.pt"
     """the checkpoint to load for the second agent"""
-    compile: bool = True
+    compile: bool = False
     """if toggled, the model will be compiled"""
     optimize: bool = False
     """if toggled, the model will be optimized"""
@@ -130,33 +130,37 @@ if __name__ == "__main__":
     envs.num_envs = num_envs
     envs = RecordEpisodeStatistics(envs)

-    embedding_shape = args.num_embeddings
-    if embedding_shape is None:
-        with open(args.code_list_file, "r") as f:
-            code_list = f.readlines()
-            embedding_shape = len(code_list)
-    L = args.num_layers
-    agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-    agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-
-    for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
-        state_dict = torch.load(ckpt, map_location=device)
-        if not args.compile:
-            prefix = "_orig_mod."
-            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-        print(agent.load_state_dict(state_dict))
-
-    if args.compile:
-        predict_step = torch.compile(predict_step, mode='reduce-overhead')
-    else:
-        if args.optimize:
-            obs = create_obs(envs.observation_space, (num_envs,), device=device)
-            def optimize_for_inference(agent):
-                with torch.no_grad():
-                    traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
-                    return torch.jit.optimize_for_inference(traced_model)
-            agent1 = optimize_for_inference(agent1)
-            agent2 = optimize_for_inference(agent2)
+    if args.checkpoint1.endswith(".ptj"):
+        agent1 = torch.jit.load(args.checkpoint1)
+        agent2 = torch.jit.load(args.checkpoint2)
+    else:
+        embedding_shape = args.num_embeddings
+        if embedding_shape is None:
+            with open(args.code_list_file, "r") as f:
+                code_list = f.readlines()
+                embedding_shape = len(code_list)
+        L = args.num_layers
+        agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+        agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+
+        for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
+            state_dict = torch.load(ckpt, map_location=device)
+            if not args.compile:
+                prefix = "_orig_mod."
+                state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
+            print(agent.load_state_dict(state_dict))
+
+        if args.compile:
+            predict_step = torch.compile(predict_step, mode='reduce-overhead')
+        else:
+            if args.optimize:
+                obs = create_obs(envs.observation_space, (num_envs,), device=device)
+                def optimize_for_inference(agent):
+                    with torch.no_grad():
+                        traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
+                        return torch.jit.optimize_for_inference(traced_model)
+                agent1 = optimize_for_inference(agent1)
+                agent2 = optimize_for_inference(agent2)

     obs, infos = envs.reset()
     next_to_play_ = infos['to_play']
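Not part of the commit: for readers following the new loading logic, below is a minimal sketch of the two checkpoint paths battle.py now distinguishes, either a TorchScript artifact (*.ptj) or an eager state_dict checkpoint. The helper name load_battle_agent and the make_agent factory are illustrative stand-ins, not code from the repository.

import torch

def load_battle_agent(checkpoint, make_agent, device="cpu", compiled=False):
    """Load either a TorchScript artifact (*.ptj) or an eager state_dict checkpoint.

    make_agent is a zero-argument factory returning an uninitialized model
    (e.g. lambda: Agent(...)); it stands in for the repository's Agent class.
    """
    if checkpoint.endswith(".ptj"):
        # TorchScript path: graph and weights are bundled, no model class needed.
        return torch.jit.load(checkpoint, map_location=device)

    agent = make_agent().to(device)
    state_dict = torch.load(checkpoint, map_location=device)
    if not compiled:
        # Checkpoints saved from a torch.compile()-wrapped module carry an
        # "_orig_mod." prefix on every key; strip it before loading into the
        # plain (uncompiled) module.
        prefix = "_orig_mod."
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in state_dict.items()}
    agent.load_state_dict(state_dict)
    return agent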
scripts/eval.py
@@ -80,6 +80,9 @@ class Args:
     """if toggled, the model will be compiled"""
     optimize: bool = True
     """if toggled, the model will be optimized"""
+    convert: bool = False
+    """if toggled, the model will be converted to a jit model and the program will exit"""
+
     torch_threads: Optional[int] = None
     """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
     env_threads: Optional[int] = 16
@@ -156,6 +159,21 @@ if __name__ == "__main__":
     print(agent.load_state_dict(state_dict))

     if args.compile:
+        if args.convert:
+            # Don't support dynamic shapes and very slow inference
+            raise NotImplementedError
+            # obs = create_obs(envs.observation_space, (num_envs,), device=device)
+            # dynamic_shapes = {"x": {}}
+            # # batch_dim = torch.export.Dim("batch", min=1, max=64)
+            # batch_dim = None
+            # for k, v in obs.items():
+            #     dynamic_shapes["x"][k] = {0: batch_dim}
+            # program = torch.export.export(
+            #     agent, (obs,),
+            #     dynamic_shapes=dynamic_shapes,
+            # )
+            # torch.export.save(program, args.checkpoint + "2")
+            # exit(0)
         agent = torch.compile(agent, mode='reduce-overhead')
     elif args.optimize:
         obs = create_obs(envs.observation_space, (num_envs,), device=device)
@@ -164,6 +182,10 @@ if __name__ == "__main__":
                 traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
                 return torch.jit.optimize_for_inference(traced_model)
         agent = optimize_for_inference(agent)
+        if args.convert:
+            torch.jit.save(agent, args.checkpoint + "j")
+            print(f"Optimized model saved to {args.checkpoint}j")
+            exit(0)

     obs, infos = envs.reset()
     next_to_play = infos['to_play']
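Not part of the commit: the new convert flag closes the loop with the battle.py change above. With convert enabled, eval.py traces the agent, applies torch.jit.optimize_for_inference, and saves the result next to the checkpoint with a "j" suffix (checkpoints/agent.pt becomes checkpoints/agent.ptj), which battle.py then loads via torch.jit.load without needing the Agent class. Below is a minimal sketch of that trace-and-save step; the toy nn.Sequential and the file name agent.pt are illustrative stand-ins, not the repository's model or paths.

import torch
import torch.nn as nn

# Toy stand-in for the repository's Agent; the real model takes dict observations.
model = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2)).eval()
example_input = torch.randn(1, 8)  # stand-in for create_obs(...)

with torch.no_grad():
    # Trace with a fixed example batch, then apply inference-only graph
    # optimizations, mirroring the optimize_for_inference() helper in eval.py.
    traced = torch.jit.trace(model, (example_input,),
                             check_tolerance=False, check_trace=False)
    optimized = torch.jit.optimize_for_inference(traced)

checkpoint = "agent.pt"                      # illustrative path
torch.jit.save(optimized, checkpoint + "j")  # -> agent.ptj

# Later, battle.py (or any consumer) can load it without the model class:
agent = torch.jit.load(checkpoint + "j")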
scripts/ppo.py
@@ -425,7 +425,7 @@ def run(local_rank, world_size):
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         if local_rank == 0:
             if iteration % args.save_interval == 0 or iteration == args.num_iterations:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))

             writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
             writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
scripts/ppo_sp.py
@@ -21,7 +21,7 @@ from torch.cuda.amp import GradScaler, autocast
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
 from ygoai.rl.agent import PPOAgent as Agent
-from ygoai.rl.dist import reduce_gradidents, mp_start, setup, fprint
+from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
@@ -118,8 +118,6 @@ class Args:
     """the number of iterations to save the model"""
     log_p: float = 1.0
     """the probability of logging"""
-    port: int = 12356
-    """the port to use for distributed training"""
     eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
     eval_interval: int = 10
@@ -140,7 +138,12 @@ class Args:
     """the number of processes (computed in runtime)"""


-def run(local_rank, world_size):
+def main():
+    rank = int(os.environ.get("RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
+
     args = tyro.cli(Args)
     args.world_size = world_size
     args.local_num_envs = args.num_envs // args.world_size
@@ -158,12 +161,12 @@ def run(local_rank, world_size):
     torch.set_float32_matmul_precision('high')

     if args.world_size > 1:
-        setup(args.backend, local_rank, args.world_size, args.port)
+        torchrun_setup(args.backend, local_rank)

     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
     writer = None
-    if local_rank == 0:
+    if rank == 0:
         from torch.utils.tensorboard import SummaryWriter
         writer = SummaryWriter(os.path.join(args.tb_dir, run_name))

         writer.add_text(
@@ -177,10 +180,10 @@ def run(local_rank, world_size):
     # TRY NOT TO MODIFY: seeding
     # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
-    args.seed += local_rank
+    args.seed += rank
     random.seed(args.seed)
     np.random.seed(args.seed)
-    torch.manual_seed(args.seed - local_rank)
+    torch.manual_seed(args.seed - rank)
     if args.torch_deterministic:
         torch.backends.cudnn.deterministic = True
     else:
@@ -188,7 +191,7 @@ def run(local_rank, world_size):
     device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")

-    deck = init_ygopro("english", args.deck, args.code_list_file)
+    deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
     args.deck1 = args.deck1 or deck
     args.deck2 = args.deck2 or deck
@@ -429,7 +432,8 @@ def run(local_rank, world_size):
                     writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)

         collect_time = time.time() - collect_start
-        fprint(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
+        if local_rank == 0:
+            fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")

         _start = time.time()
         # bootstrap value if not done
@@ -561,16 +565,17 @@ def run(local_rank, world_size):
         train_time = time.time() - _start
-        fprint(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
+        if local_rank == 0:
+            fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")

         y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
         var_y = np.var(y_true)
         explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

         # TRY NOT TO MODIFY: record rewards for plotting purposes
-        if local_rank == 0:
+        if rank == 0:
             if iteration % args.save_interval == 0:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))

             writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
             writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -581,15 +586,17 @@ def run(local_rank, world_size):
             writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
             writer.add_scalar("losses/explained_variance", explained_var, global_step)

         SPS = int((global_step - warmup_steps) / (time.time() - start_time))

         # Warmup at first few iterations for accurate SPS measurement
         SPS_warmup_iters = 10
         if iteration == SPS_warmup_iters:
             start_time = time.time()
             warmup_steps = global_step
         if iteration > SPS_warmup_iters:
-            fprint(f"SPS: {SPS}")
-            writer.add_scalar("charts/SPS", SPS, global_step)
+            if local_rank == 0:
+                fprint(f"SPS: {SPS}")
+            if rank == 0:
+                writer.add_scalar("charts/SPS", SPS, global_step)

         if iteration % args.eval_interval == 0:
@@ -628,11 +635,12 @@ def run(local_rank, world_size):
             # sync the statistics
             if args.world_size > 1:
                 dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
-            if local_rank == 0:
-                eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
+            eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
+            if rank == 0:
                 writer.add_scalar("charts/eval_return", eval_return, global_step)
                 writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
                 writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
+            if local_rank == 0:
                 eval_time = time.time() - _start
                 fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
@@ -641,10 +649,10 @@ def run(local_rank, world_size):
     if args.world_size > 1:
         dist.destroy_process_group()
     envs.close()
-    if local_rank == 0:
-        torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
+    if rank == 0:
+        torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
         writer.close()


 if __name__ == "__main__":
-    mp_start(run)
+    main()
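Not part of the commit: ppo_sp.py now assumes a torchrun-style launch, reading RANK, LOCAL_RANK, and WORLD_SIZE from the environment and calling torchrun_setup instead of the old setup(backend, local_rank, world_size, port). The implementation of ygoai.rl.dist.torchrun_setup is not shown in this diff; the sketch below shows what such a helper typically does under torchrun and is an assumption, not the project's code.

import os
import torch
import torch.distributed as dist

def torchrun_setup_sketch(backend, local_rank):
    # Assumed equivalent of ygoai.rl.dist.torchrun_setup: torchrun already
    # exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, so an env://
    # init needs no explicit port argument.
    if backend == "nccl":
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, init_method="env://")

if __name__ == "__main__":
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
    if world_size > 1:
        torchrun_setup_sketch("nccl" if torch.cuda.is_available() else "gloo", local_rank)

With this entry point the script would be launched as, for example, torchrun --nproc_per_node=4 scripts/ppo_sp.py (exact flags depend on the setup), rather than through the removed mp_start(run) spawner.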
scripts/ppo_sp2.py
deleted 100644 → 0 (658 deletions; diff collapsed)
scripts/ppo_sp5.py
@@ -530,7 +530,7 @@ def run(local_rank, world_size):
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         if local_rank == 0:
             if iteration % args.save_interval == 0:
-                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))

             writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
             writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -564,7 +564,7 @@ def run(local_rank, world_size):
             agent2.load_state_dict(agent1.state_dict())
             version += 1
             if local_rank == 0:
-                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pth"))
+                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
                 print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
                 avg_win_rates.clear()
                 avg_ep_returns.clear()
@@ -614,7 +614,7 @@ def run(local_rank, world_size):
     if args.world_size > 1:
         dist.destroy_process_group()
     envs.close()
     if local_rank == 0:
-        torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
+        torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
         writer.close()
ygoai/rl/agent.py
(+303 -1; diff collapsed)