ygo-agent (Biluo Shen) · Commits

Commit 80707a8c
Authored Mar 16, 2024 by Biluo Shen
add OSFP

Parent: 11261948

Showing 3 changed files with 168 additions and 154 deletions (+168, -154)
.gitignore            +1    -0
scripts/ppo.py        +40   -56
scripts/ppo_osfp.py   +127  -98
.gitignore  (view file @ 80707a8c)

 *.pt
+*.ptj
 *.pkl
 # Xmake cache
 ...
scripts/ppo.py  (view file @ 80707a8c)

...
@@ -3,7 +3,7 @@ import random
 import time
 from collections import deque
 from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional
 import ygoenv
...
@@ -52,10 +52,8 @@ class Args:
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 16
+    n_history_actions: int = 32
     """the number of history actions to use"""
-    play_mode: str = "bot"
-    """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
     num_layers: int = 2
     """the number of layers for the agent"""
...
@@ -74,9 +72,9 @@ class Args:
     """the number of steps to run in each environment per policy rollout"""
     anneal_lr: bool = True
     """Toggle learning rate annealing for policy and value networks"""
-    gamma: float = 0.997
+    gamma: float = 1.0
     """the discount factor gamma"""
-    gae_lambda: float = 0.95
+    gae_lambda: float = 0.98
     """the lambda for the general advantage estimation"""
     minibatch_size: int = 256
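Note: gamma moves from 0.997 to 1.0 and gae_lambda from 0.95 to 0.98 in both scripts. The repository's GAE code is not part of this diff; the following is only a textbook sketch of how these two constants enter generalized advantage estimation, on a made-up 3-step trajectory:

# Illustrative only: the standard GAE recursion with the new defaults.
# Rewards and values below are invented numbers, not from the repo.
gamma, lam = 1.0, 0.98

rewards = [0.0, 0.0, 1.0]          # hypothetical 3-step episode
values = [0.2, 0.4, 0.7]           # hypothetical value estimates
next_value = 0.0                   # terminal state, so the bootstrap value is 0

advantages = [0.0] * len(rewards)
last_gae = 0.0
for t in reversed(range(len(rewards))):
    v_next = next_value if t == len(rewards) - 1 else values[t + 1]
    delta = rewards[t] + gamma * v_next - values[t]
    last_gae = delta + gamma * lam * last_gae
    advantages[t] = last_gae

print(advantages)   # approximately [0.782, 0.594, 0.3] for these inputs

With gamma = 1.0 there is no per-step discounting (natural for win/loss game rewards), and the larger lambda keeps more of the longer-horizon return in the advantage estimate.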
...
@@ -85,7 +83,7 @@ class Args:
     """the K epochs to update the policy"""
     norm_adv: bool = True
     """Toggles advantages normalization"""
-    clip_coef: float = 0.1
+    clip_coef: float = 0.2
     """the surrogate clipping coefficient"""
     clip_vloss: bool = True
     """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
...
@@ -93,17 +91,13 @@ class Args:
     """coefficient of the entropy"""
     vf_coef: float = 0.5
     """coefficient of the value function"""
-    max_grad_norm: float = 0.5
+    max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""
-    target_kl: Optional[float] = None
-    """the target KL divergence threshold"""
-    learn_opponent: bool = True
+    learn_opponent: bool = False
     """if toggled, the samples from the opponent will be used to train the agent"""
     collect_length: Optional[int] = None
     """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-    backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-    """the backend for distributed training"""
     compile: Optional[str] = None
     """Compile mode of torch.compile, None for no compilation"""
     torch_threads: Optional[int] = None
...
@@ -125,7 +119,7 @@ class Args:
     """the probability of logging"""
     eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
-    eval_interval: int = 10
+    eval_interval: int = 50
     """the number of iterations to evaluate the model"""

     # to be filled in runtime
...
@@ -143,6 +137,23 @@ class Args:
     """the number of processes (computed in runtime)"""


+def make_env(args, num_envs, num_threads, mode='self'):
+    envs = ygoenv.make(
+        task_id=args.env_id,
+        env_type="gymnasium",
+        num_envs=num_envs,
+        num_threads=num_threads,
+        seed=args.seed,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        max_options=args.max_options,
+        n_history_actions=args.n_history_actions,
+        play_mode='self',
+    )
+    envs.num_envs = num_envs
+    envs = RecordEpisodeStatistics(envs)
+    return envs
+
+
 def main():
     rank = int(os.environ.get("RANK", 0))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
...
@@ -169,7 +180,7 @@ def main():
     torch.set_float32_matmul_precision('high')

     if args.world_size > 1:
-        torchrun_setup(args.backend, local_rank)
+        torchrun_setup('nccl', local_rank)

     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
...
@@ -204,43 +215,17 @@ def main():
     args.deck2 = args.deck2 or deck

     # env setup
-    envs = ygoenv.make(
-        task_id=args.env_id,
-        env_type="gymnasium",
-        num_envs=args.local_num_envs,
-        num_threads=local_env_threads,
-        seed=args.seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
-        max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
-        play_mode='self',
-    )
-    envs.num_envs = args.local_num_envs
-    obs_space = envs.observation_space
-    action_shape = envs.action_space.shape
+    envs = make_env(args, args.local_num_envs, local_env_threads)
+    obs_space = envs.env.observation_space
+    action_shape = envs.env.action_space.shape
     if local_rank == 0:
         fprint(f"obs_space={obs_space}, action_shape={action_shape}")

     envs_per_thread = args.local_num_envs // local_env_threads
     local_eval_episodes = args.eval_episodes // args.world_size
     local_eval_num_envs = local_eval_episodes
-    eval_envs = ygoenv.make(
-        task_id=args.env_id,
-        env_type="gymnasium",
-        num_envs=local_eval_num_envs,
-        num_threads=max(1, local_eval_num_envs // envs_per_thread),
-        seed=args.seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
-        max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
-        play_mode=args.play_mode,
-    )
-    eval_envs.num_envs = local_eval_num_envs
-
-    envs = RecordEpisodeStatistics(envs)
-    eval_envs = RecordEpisodeStatistics(eval_envs)
+    local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
+    eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')

     if args.embedding_file:
         embeddings = load_embeddings(args.embedding_file, args.code_list_file)
...
@@ -312,10 +297,10 @@ def main():
     next_value1 = next_value2 = 0
     step = 0

-    for iteration in range(1, args.num_iterations + 1):
+    for iteration in range(args.num_iterations):
         # Annealing the rate if instructed to do so.
         if args.anneal_lr:
-            frac = 1.0 - (iteration - 1.0) / args.num_iterations
+            frac = 1.0 - iteration / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow
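Note: the loop moves from a 1-based `range(1, num_iterations + 1)` to a 0-based `range(num_iterations)`, and the annealing fraction drops the `- 1.0` correction accordingly, so the linear decay itself is unchanged. A quick sketch of the resulting schedule, using a hypothetical base learning rate that is not taken from the repo:

# Illustrative only: the linear LR decay produced by the annealing code above.
learning_rate = 2.5e-4      # hypothetical base LR
num_iterations = 4          # tiny run, just to print the schedule

for iteration in range(num_iterations):          # 0-based, as in the new code
    frac = 1.0 - iteration / num_iterations
    lrnow = frac * learning_rate
    print(iteration, lrnow)
# 0 0.00025, 1 0.0001875, 2 0.000125, 3 6.25e-05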
...
@@ -372,7 +357,7 @@ def main():
                     if random.random() < args.log_p:
                         n = 100
-                        if random.random() < 10/n or iteration <= 2:
+                        if random.random() < 10/n or iteration <= 1:
                             writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                             writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
                             fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
...
@@ -394,7 +379,7 @@ def main():
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)

-        if step > 0 and iteration != 1:
+        if step > 0 and iteration != 0:
             # recalculate the values for the first few steps
             v_steps = args.local_minibatch_size * 4 // args.local_num_envs
             for v_start in range(0, step, v_steps):
...
@@ -421,8 +406,11 @@ def main():
         b_logprobs = logprobs[:args.num_steps].reshape(-1)
         b_advantages = advantages[:args.num_steps].reshape(-1)
         b_values = values[:args.num_steps].reshape(-1)
-        b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
         b_returns = b_advantages + b_values
+        if args.learn_opponent:
+            b_learns = torch.ones_like(b_values, dtype=torch.bool)
+        else:
+            b_learns = learns[:args.num_steps].reshape(-1)

         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
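Note: with `learn_opponent` now defaulting to False, `b_learns` carries the per-step flags marking which transitions belong to the learning agent. How `train_step` consumes the mask is not shown in this diff; the following is only a plausible standalone sketch of gating per-sample losses with such a boolean mask:

# Illustrative sketch (not the repo's train_step): keep only the learner's
# transitions in the loss by masking and renormalizing.
import torch

per_sample_loss = torch.tensor([0.5, 1.0, 2.0, 4.0])    # made-up per-sample losses
b_learns = torch.tensor([True, False, True, False])      # made-up learner mask

masked = per_sample_loss * b_learns.float()
loss = masked.sum() / b_learns.float().sum().clamp(min=1)   # mean over selected samples
print(loss)   # tensor(1.2500) -> (0.5 + 2.0) / 2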
...
@@ -444,9 +432,6 @@ def main():
                 scaler.update()

             clipfracs.append(clipfrac.item())
-            if args.target_kl is not None and approx_kl > args.target_kl:
-                break
-
         if step > 0:
             # TODO: use cyclic buffer to avoid copying
             for v in obs.values():
...
@@ -463,7 +448,6 @@ def main():
         var_y = np.var(y_true)
         explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

-        # TRY NOT TO MODIFY: record rewards for plotting purposes
         if rank == 0:
             if iteration % args.save_interval == 0:
                 torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
...
@@ -490,7 +474,7 @@ def main():
         if rank == 0:
             writer.add_scalar("charts/SPS", SPS, global_step)

-        if iteration % args.eval_interval == 0:
+        if args.eval_interval and iteration % args.eval_interval == 0:
             # Eval with rule-based policy
             _start = time.time()
             eval_return = evaluate(
...
scripts/ppo_osfp.py  (view file @ 80707a8c)

 import os
 import random
 import time
-from collections import deque
 from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional
 import ygoenv
...
@@ -52,10 +51,8 @@ class Args:
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 16
+    n_history_actions: int = 32
     """the number of history actions to use"""
-    play_mode: str = "bot"
-    """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
     num_layers: int = 2
     """the number of layers for the agent"""
...
@@ -74,15 +71,21 @@ class Args:
     """the number of steps to run in each environment per policy rollout"""
     anneal_lr: bool = True
     """Toggle learning rate annealing for policy and value networks"""
-    gamma: float = 0.997
+    gamma: float = 1.0
     """the discount factor gamma"""
-    gae_lambda: float = 0.95
+    gae_lambda: float = 0.98
     """the lambda for the general advantage estimation"""
     update_win_rate: float = 0.55
     """the required win rate to update the agent"""
-    update_return: float = 0.1
-    """the required return to update the agent"""
+    self_play_prob: float = 0.6
+    """the probability of self play"""
+    max_lp: int = 6
+    """the maximum number of LP to add model to the pool"""
+    iter_per_lp: int = 1000
+    """the number of iterations per learning phase"""
+    target_sample_iter: int = 10
+    """the number of iterations to sample the target model"""
     minibatch_size: int = 256
     """the mini-batch size"""
...
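Note: the single `update_return` threshold is replaced by a small scheduling block: `self_play_prob` controls how often rollouts are pure self-play, `iter_per_lp` sets the length of a learning phase, `max_lp` caps how many phases may pass before the current agent is added to the pool regardless of win rate (see the pool-update hunk near the end of this file), and `target_sample_iter` sets how many opponent assignments are drawn and synchronized at a time. With the defaults shown here, a back-of-the-envelope bound on when a pool update is forced:

# Illustrative arithmetic with the defaults from the diff.
iter_per_lp = 1000   # iterations per learning phase
max_lp = 6           # phases before a pool update is forced

# A pool-update check happens every iter_per_lp iterations; after at most
# max_lp checks that fail the win-rate test, the agent is added anyway.
max_iters_before_forced_update = iter_per_lp * max_lp
print(max_iters_before_forced_update)   # 6000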
@@ -90,7 +93,7 @@ class Args:
     """the K epochs to update the policy"""
     norm_adv: bool = True
     """Toggles advantages normalization"""
-    clip_coef: float = 0.1
+    clip_coef: float = 0.2
     """the surrogate clipping coefficient"""
     clip_vloss: bool = True
     """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
...
@@ -98,17 +101,13 @@ class Args:
     """coefficient of the entropy"""
     vf_coef: float = 0.5
     """coefficient of the value function"""
-    max_grad_norm: float = 0.5
+    max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""
-    target_kl: Optional[float] = None
-    """the target KL divergence threshold"""
-    learn_opponent: bool = True
+    learn_opponent: bool = False
     """if toggled, the samples from the opponent will be used to train the agent"""
     collect_length: Optional[int] = None
     """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-    backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-    """the backend for distributed training"""
     compile: Optional[str] = None
     """Compile mode of torch.compile, None for no compilation"""
     torch_threads: Optional[int] = None
...
@@ -130,7 +129,7 @@ class Args:
     """the probability of logging"""
     eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
-    eval_interval: int = 10
+    eval_interval: int = 50
     """the number of iterations to evaluate the model"""

     # to be filled in runtime
...
@@ -148,6 +147,27 @@ class Args:
     """the number of processes (computed in runtime)"""


+def make_env(args, num_envs, num_threads, mode='self'):
+    envs = ygoenv.make(
+        task_id=args.env_id,
+        env_type="gymnasium",
+        num_envs=num_envs,
+        num_threads=num_threads,
+        seed=args.seed,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        max_options=args.max_options,
+        n_history_actions=args.n_history_actions,
+        play_mode='self',
+    )
+    envs.num_envs = num_envs
+    envs = RecordEpisodeStatistics(envs)
+    return envs
+
+
+def update_running_mean(mean, value, count):
+    return mean + (value - mean) / count
+
+
 def main():
     rank = int(os.environ.get("RANK", 0))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
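Note: `update_running_mean` is the standard incremental mean update: with the n-th observation, the mean becomes mean + (x - mean) / n, which avoids keeping the full history of per-opponent returns and win flags. A quick check that it matches the plain arithmetic mean (inputs are made up):

# Quick check that the incremental update matches a plain average.
def update_running_mean(mean, value, count):
    return mean + (value - mean) / count

xs = [1.0, 0.0, 1.0, 1.0]      # e.g. a win/loss sequence
mean, n = 0.0, 0
for x in xs:
    n += 1                      # mirrors n_episodes[t_idx] += 1 before the call
    mean = update_running_mean(mean, x, n)
print(mean, sum(xs) / len(xs))  # 0.75 0.75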
...
@@ -174,7 +194,19 @@ def main():
     torch.set_float32_matmul_precision('high')

     if args.world_size > 1:
-        torchrun_setup(args.backend, local_rank)
+        torchrun_setup('nccl', local_rank)
+
+    def sync_var(var, dtype=torch.float32, reduce='first'):
+        ts = torch.tensor(var, dtype=dtype, device=device)
+        if reduce == 'mean':
+            if args.world_size > 1:
+                dist.all_reduce(ts, op=dist.ReduceOp.AVG)
+        else:
+            if rank != 0:
+                ts = torch.zeros_like(ts)
+            if args.world_size > 1:
+                dist.all_reduce(ts, op=dist.ReduceOp.SUM)
+        return ts.cpu().numpy()

     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
...
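Note: `sync_var` keeps a Python value consistent across ranks. With reduce='mean' it averages element-wise over ranks; otherwise it zeroes the value on every rank except 0 and sums, which effectively broadcasts rank 0's value. A single-process simulation of the two modes, with made-up per-rank values and no torch.distributed involved:

# Single-process simulation of sync_var's two reduce modes (illustrative only).
import numpy as np

rank_values = {0: [0.6, 0.7], 1: [0.5, 0.9], 2: [0.4, 0.8]}   # made-up per-rank win rates

# reduce='mean': element-wise average over ranks (what ReduceOp.AVG computes)
mean_result = np.mean([v for v in rank_values.values()], axis=0)

# reduce='first': zero everywhere except rank 0, then SUM -> rank 0's value everywhere
first_result = np.sum([np.asarray(v, dtype=float) if r == 0 else np.zeros(len(v))
                       for r, v in rank_values.items()], axis=0)

print(mean_result)    # [0.5 0.8]
print(first_result)   # [0.6 0.7]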
@@ -209,43 +241,17 @@ def main():
     args.deck2 = args.deck2 or deck

     # env setup
-    envs = ygoenv.make(
-        task_id=args.env_id,
-        env_type="gymnasium",
-        num_envs=args.local_num_envs,
-        num_threads=local_env_threads,
-        seed=args.seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
-        max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
-        play_mode='self',
-    )
-    envs.num_envs = args.local_num_envs
-    obs_space = envs.observation_space
-    action_shape = envs.action_space.shape
+    envs = make_env(args, args.local_num_envs, local_env_threads)
+    obs_space = envs.env.observation_space
+    action_shape = envs.env.action_space.shape
     if local_rank == 0:
         fprint(f"obs_space={obs_space}, action_shape={action_shape}")

     envs_per_thread = args.local_num_envs // local_env_threads
     local_eval_episodes = args.eval_episodes // args.world_size
     local_eval_num_envs = local_eval_episodes
-    eval_envs = ygoenv.make(
-        task_id=args.env_id,
-        env_type="gymnasium",
-        num_envs=local_eval_num_envs,
-        num_threads=max(1, local_eval_num_envs // envs_per_thread),
-        seed=args.seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
-        max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
-        play_mode=args.play_mode,
-    )
-    eval_envs.num_envs = local_eval_num_envs
-
-    envs = RecordEpisodeStatistics(envs)
-    eval_envs = RecordEpisodeStatistics(eval_envs)
+    local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
+    eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')

     if args.embedding_file:
         embeddings = load_embeddings(args.embedding_file, args.code_list_file)
...
@@ -280,6 +286,8 @@ def main():
             logits, value, valid = agent(next_obs)
         return logits, value

+    history = []
+
     from ygoai.rl.ppo import train_step
     if args.compile:
         # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
...
@@ -287,13 +295,23 @@ def main():
         example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
         with torch.no_grad():
             traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
-            traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
-            traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
+
+        if args.checkpoint:
+            with torch.no_grad():
+                traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
+                traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
+            history.append(traced_model_t)

         train_step = torch.compile(train_step, mode=args.compile)
     else:
         traced_model = agent
-        traced_model_t = agent_t
+
+    def sample_target(history):
+        ts = []
+        for i in range(args.target_sample_iter):
+            if len(history) == 0 or random.random() < args.self_play_prob:
+                ts.append(-1)
+            else:
+                ts.append(random.randint(0, len(history) - 1))
+        ts.sort(reverse=True)
+        return sync_var(ts, dtype=torch.int64).tolist()

     # ALGO Logic: Storage setup
     obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
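Note: `sample_target` draws `target_sample_iter` opponent indices per call: -1 means self-play (chosen with probability `self_play_prob`, or always while the pool is empty), otherwise a uniformly random index into `history`. The list is sorted descending so `pop()` hands out the self-play assignments first, and `sync_var` keeps every rank on the same schedule. A standalone sketch of the resulting sampling distribution (pool size and counts are made up):

# Illustrative only: empirical opponent distribution produced by the sampling rule.
import random
from collections import Counter

self_play_prob, target_sample_iter = 0.6, 10
pool_size = 4          # pretend four snapshots are already in `history`

def sample_targets():
    ts = []
    for _ in range(target_sample_iter):
        if pool_size == 0 or random.random() < self_play_prob:
            ts.append(-1)                                  # -1 => play against the current agent
        else:
            ts.append(random.randint(0, pool_size - 1))    # uniform over pool snapshots
    ts.sort(reverse=True)
    return ts

counts = Counter(t for _ in range(10_000) for t in sample_targets())
print(counts)   # roughly 60% of draws are -1, the rest split evenly over 0..3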
...
@@ -303,9 +321,9 @@ def main():
     dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
     values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
     learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
-    avg_ep_returns = deque(maxlen=1000)
-    avg_win_rates = deque(maxlen=1000)
-    version = 0
+    avg_ep_returns = [0]
+    avg_win_rates = [0]
+    n_episodes = [0]

     # TRY NOT TO MODIFY: start the game
     global_step = 0
...
@@ -324,14 +342,23 @@ def main():
     ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
     next_value1 = next_value2 = 0
     step = 0
+    lp_count = 0
+    ts = sample_target(history)

-    for iteration in range(1, args.num_iterations + 1):
+    for iteration in range(args.num_iterations):
         # Annealing the rate if instructed to do so.
         if args.anneal_lr:
-            frac = 1.0 - (iteration - 1.0) / args.num_iterations
+            frac = 1.0 - (iteration % args.iter_per_lp) / args.iter_per_lp
             lrnow = frac * args.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow

+        if len(ts) == 0:
+            ts = sample_target(history)
+        t_idx = ts.pop()
+        selfplay = t_idx == -1
+        if not selfplay:
+            traced_model_t = history[t_idx]
+
         model_time = 0
         env_time = 0
         collect_start = time.time()
...
@@ -346,9 +373,10 @@ def main():
                 _start = time.time()
                 logits, value = predict_step(traced_model, next_obs)
-                logits_t, value_t = predict_step(traced_model_t, next_obs)
-                logits = torch.where(learn[:, None], logits, logits_t)
-                value = torch.where(learn[:, None], value, value_t)
+                if not selfplay:
+                    logits_t, value_t = predict_step(traced_model_t, next_obs)
+                    logits = torch.where(learn[:, None], logits, logits_t)
+                    value = torch.where(learn[:, None], value, value_t)
                 value = value.flatten()

             probs = Categorical(logits=logits)
             action = probs.sample()
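Note: each environment is flagged by `learn` according to whose turn it is. When a pool snapshot is the opponent, `torch.where(learn[:, None], logits, logits_t)` keeps the current agent's logits for the learner's seats and substitutes the snapshot's logits elsewhere (values are mixed the same way); in pure self-play the snapshot forward pass is skipped entirely. A minimal sketch of that row-wise selection with made-up tensors:

# Illustrative sketch of per-environment mixing of current vs. snapshot outputs.
import torch

learn = torch.tensor([True, False, True])            # is the learner acting in this env?
logits = torch.tensor([[1.0, 2.0],                   # current agent's logits (made up)
                       [1.0, 2.0],
                       [1.0, 2.0]])
logits_t = torch.tensor([[9.0, 9.0],                 # pool snapshot's logits (made up)
                         [9.0, 9.0],
                         [9.0, 9.0]])

mixed = torch.where(learn[:, None], logits, logits_t)
print(mixed)   # rows 0 and 2 come from `logits`, row 1 from `logits_t`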
...
@@ -374,21 +402,20 @@ def main():
             next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
             step += 1

-            if not writer:
-                continue
-
             for idx, d in enumerate(next_done_):
                 if d:
                     pl = 1 if to_play[idx] == ai_player1_[idx] else -1
                     episode_length = info['l'][idx]
                     episode_reward = info['r'][idx] * pl
                     win = 1 if episode_reward > 0 else 0
-                    avg_ep_returns.append(episode_reward)
-                    avg_win_rates.append(win)
+                    if len(history) == 0 or not selfplay:
+                        n_episodes[t_idx] += 1
+                        avg_ep_returns[t_idx] = update_running_mean(avg_ep_returns[t_idx], episode_reward, n_episodes[t_idx])
+                        avg_win_rates[t_idx] = update_running_mean(avg_win_rates[t_idx], win, n_episodes[t_idx])

-                    if random.random() < args.log_p:
+                    if writer and random.random() < args.log_p:
                         n = 100
-                        if random.random() < 10/n or iteration <= 2:
+                        if random.random() < 10/n or iteration <= 1:
                             writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                             writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
                             fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
...
@@ -407,12 +434,13 @@ def main():
         # bootstrap value if not done
         with torch.no_grad():
             value = predict_step(traced_model, next_obs)[1].reshape(-1)
-            value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
-            value = torch.where(next_to_play == ai_player1, value, value_t)
+            if not selfplay:
+                value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
+                value = torch.where(next_to_play == ai_player1, value, value_t)
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)

-        if step > 0 and iteration != 1:
+        if step > 0 and iteration != 0:
             # recalculate the values for the first few steps
             v_steps = args.local_minibatch_size * 4 // args.local_num_envs
             for v_start in range(0, step, v_steps):
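Note: because the two players alternate turns within the same trajectory, the bootstrap keeps one running value per player: `nextvalues1` is refreshed only on steps where player 1 is to move and `nextvalues2` only on the other steps, each via `torch.where`. A small sketch of that selection with invented tensors:

# Illustrative sketch: keep a separate bootstrap value for each player.
import torch

next_to_play = torch.tensor([0, 1, 0])       # which player acts next in each env (made up)
ai_player1 = torch.tensor([0, 0, 1])         # which seat the learner occupies per env (made up)
value = torch.tensor([0.3, -0.1, 0.5])       # freshly predicted values (made up)
next_value1 = torch.zeros(3)                 # previous bootstrap for player 1
next_value2 = torch.zeros(3)                 # previous bootstrap for player 2

nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
print(nextvalues1)   # tensor([0.3000, 0.0000, 0.0000])
print(nextvalues2)   # tensor([0.0000, -0.1000, 0.5000])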
...
@@ -439,8 +467,11 @@ def main():
         b_logprobs = logprobs[:args.num_steps].reshape(-1)
         b_advantages = advantages[:args.num_steps].reshape(-1)
         b_values = values[:args.num_steps].reshape(-1)
-        b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
         b_returns = b_advantages + b_values
+        if args.learn_opponent or selfplay:
+            b_learns = torch.ones_like(b_values, dtype=torch.bool)
+        else:
+            b_learns = learns[:args.num_steps].reshape(-1)

         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
...
@@ -462,9 +493,6 @@ def main():
                 scaler.update()

             clipfracs.append(clipfrac.item())
-            if args.target_kl is not None and approx_kl > args.target_kl:
-                break
-
         if step > 0:
             # TODO: use cyclic buffer to avoid copying
             for v in obs.values():
...
@@ -481,7 +509,6 @@ def main():
         var_y = np.var(y_true)
         explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

-        # TRY NOT TO MODIFY: record rewards for plotting purposes
         if rank == 0:
             if iteration % args.save_interval == 0:
                 torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
...
@@ -508,27 +535,29 @@ def main():
         if rank == 0:
             writer.add_scalar("charts/SPS", SPS, global_step)

-        if rank == 0:
-            should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
-            should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
-        else:
-            should_update = torch.zeros((), dtype=torch.int64, device=device)
-        if args.world_size > 1:
-            dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
-        should_update = should_update.item() > 0
-
-        if should_update:
-            agent_t.load_state_dict(agent.state_dict())
-            with torch.no_grad():
-                traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
-                traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
-            version += 1
-            if rank == 0:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
-                print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
-            avg_win_rates.clear()
-            avg_ep_returns.clear()
+        if (iteration + 1) % args.iter_per_lp == 0:
+            lp_count += 1
+            win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
+            if np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
+                agent_t.load_state_dict(agent.state_dict())
+                with torch.no_grad():
+                    traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
+                    traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
+                history.append(traced_model_t)
+                lp_count = 0
+                if rank == 0:
+                    version = len(history)
+                    torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
+                    fprint(f"model v{version} added to the pool, win_rates={win_rates}")
+            else:
+                if rank == 0:
+                    fprint(f"win_rates={win_rates}, not updating the pool")
+            avg_ep_returns = [0] * len(history)
+            avg_win_rates = [0] * len(history)
+            n_episodes = [0] * len(history)

-        if iteration % args.eval_interval == 0:
+        if args.eval_interval and iteration % args.eval_interval == 0:
             # Eval with rule-based policy
             _start = time.time()
             eval_return = evaluate(
                 eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
...
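Note: pool maintenance now runs once per learning phase. Win rates against every tracked opponent (plus self-play) are averaged across ranks with `sync_var(..., reduce='mean')`, and the current agent is appended to the pool either when it clears `update_win_rate` against all of them or when `max_lp` phases have passed; the per-opponent statistics are then reset to match the new pool size. A condensed sketch of just the gating rule, with made-up synced win rates:

# Illustrative only: the pool-update gate.
import numpy as np

update_win_rate, max_lp = 0.55, 6

def should_add_to_pool(win_rates, lp_count):
    return bool(np.all(np.asarray(win_rates) > update_win_rate)) or lp_count >= max_lp

print(should_add_to_pool([0.62, 0.58, 0.71], lp_count=1))   # True: every rate clears 0.55
print(should_add_to_pool([0.62, 0.48, 0.71], lp_count=1))   # False: one opponent is still too strong
print(should_add_to_pool([0.62, 0.48, 0.71], lp_count=6))   # True: forced after max_lp phases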