Biluo Shen / ygo-agent

Commit 11261948, authored Mar 16, 2024 by Biluo Shen

    (WIP) OSFP

Parent: 4d07e48e

Changes: 10 changed files with 675 additions and 69 deletions (+675 −69)
.gitignore                        +3   −0
scripts/battle.py                 +2   −2
scripts/eval.py                   +1   −1
scripts/ppo.py                    +7   −6
scripts/ppo_lstm.py               +1   −1
scripts/ppo_osfp.py (new)         +558 −0
scripts/ppo_t.py                  +61  −29
ygoai/rl/agent.py                 +40  −27
ygoai/rl/ppo.py                   +1   −2
ygoenv/ygoenv/ygopro/ygopro.h     +1   −1
.gitignore  +3 −0

+*.pt
+*.pkl
 # Xmake cache
 .xmake/
...
scripts/battle.py  +2 −2

...
@@ -140,8 +140,8 @@ if __name__ == "__main__":
         code_list = f.readlines()
     embedding_shape = len(code_list)
     L = args.num_layers
-    agent1 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
-    agent2 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
+    agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
+    agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
     for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
         state_dict = torch.load(ckpt, map_location=device)
...
scripts/eval.py  +1 −1

...
@@ -154,7 +154,7 @@ if __name__ == "__main__":
         code_list = f.readlines()
     embedding_shape = len(code_list)
     L = args.num_layers
-    agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
+    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
     if args.checkpoint:
         state_dict = torch.load(args.checkpoint, map_location=device)
         if not args.compile:
...
scripts/ppo.py  +7 −6

...
@@ -5,6 +5,7 @@ from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional

 import ygoenv
 import numpy as np
 import tyro
...
@@ -247,7 +248,7 @@ def main():
     else:
         embedding_shape = None
     L = args.num_layers
-    agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
+    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
     agent.eval()

     if args.checkpoint:
...
@@ -274,9 +275,9 @@ def main():
     if args.compile:
         # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
         # predict_step = torch.compile(predict_step, mode=args.compile)
-        obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
+        example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
         with torch.no_grad():
-            traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
+            traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)

         train_step = torch.compile(train_step, mode=args.compile)
     else:
...
@@ -389,7 +390,7 @@ def main():
         _start = time.time()
         # bootstrap value if not done
         with torch.no_grad():
-            value = traced_model(next_obs)[1].reshape(-1)
+            value = predict_step(traced_model, next_obs)[1].reshape(-1)
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
...
@@ -403,7 +404,7 @@ def main():
                 }
                 with torch.no_grad():
                     # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
-                    value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
+                    value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
                 values[v_start:v_end] = value

         advantages = bootstrap_value_selfplay(
...
@@ -420,7 +421,7 @@ def main():
         b_logprobs = logprobs[:args.num_steps].reshape(-1)
         b_advantages = advantages[:args.num_steps].reshape(-1)
         b_values = values[:args.num_steps].reshape(-1)
-        b_learns = learns[:args.num_steps].reshape(-1)
+        b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
         b_returns = b_advantages + b_values

         # Optimizing the policy and value network
...
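The ppo.py changes above stop calling the traced model directly and route all inference through a predict_step helper instead. The sketch below mirrors the helper defined in scripts/ppo_osfp.py in this same commit; the fp16_eval flag is lifted into a parameter here for illustration, and the agent forward is assumed to return (logits, value, valid).

import torch
from torch.cuda.amp import autocast

def predict_step(agent, next_obs, fp16_eval=False):
    # Inference-only forward pass: no gradients, optional fp16 autocast.
    with torch.no_grad():
        with autocast(enabled=fp16_eval):
            logits, value, valid = agent(next_obs)
    return logits, value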
scripts/ppo_lstm.py  +1 −1

...
@@ -243,7 +243,7 @@ def main():
     else:
         embedding_shape = None
     L = args.num_layers
-    agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
+    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
     if args.checkpoint:
         agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
...
scripts/ppo_osfp.py  0 → 100644  (new file, +558 −0)

import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional

import ygoenv
import numpy as np
import tyro

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast

from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.eval import evaluate


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = False
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""

    # Algorithm specific arguments
    env_id: str = "YGOPro-v0"
    """the id of the environment"""
    deck: str = "../assets/deck"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    embedding_file: Optional[str] = None
    """the embedding file for card embeddings"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 16
    """the number of history actions to use"""
    play_mode: str = "bot"
    """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""

    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    checkpoint: Optional[str] = None
    """the checkpoint to load the model from"""

    total_timesteps: int = 2000000000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.997
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    update_win_rate: float = 0.55
    """the required win rate to update the agent"""
    update_return: float = 0.1
    """the required return to update the agent"""
    minibatch_size: int = 256
    """the mini-batch size"""
    update_epochs: int = 2
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.1
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: Optional[float] = None
    """the target KL divergence threshold"""
    learn_opponent: bool = True
    """if toggled, the samples from the opponent will be used to train the agent"""
    collect_length: Optional[int] = None
    """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""

    backend: Literal["gloo", "nccl", "mpi"] = "nccl"
    """the backend for distributed training"""
    compile: Optional[str] = None
    """Compile mode of torch.compile, None for no compilation"""
    torch_threads: Optional[int] = None
    """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
    env_threads: Optional[int] = None
    """the number of threads to use for envpool, defaults to `num_envs`"""
    fp16_train: bool = False
    """if toggled, training will be done in fp16 precision"""
    fp16_eval: bool = False
    """if toggled, evaluation will be done in fp16 precision"""

    tb_dir: str = "./runs"
    """tensorboard log directory"""
    ckpt_dir: str = "./checkpoints"
    """checkpoint directory"""
    save_interval: int = 500
    """the number of iterations to save the model"""
    log_p: float = 1.0
    """the probability of logging"""
    eval_episodes: int = 128
    """the number of episodes to evaluate the model"""
    eval_interval: int = 10
    """the number of iterations to evaluate the model"""

    # to be filled in runtime
    local_batch_size: int = 0
    """the local batch size in the local rank (computed in runtime)"""
    local_minibatch_size: int = 0
    """the local mini-batch size in the local rank (computed in runtime)"""
    local_num_envs: int = 0
    """the number of parallel game environments (in the local rank, computed in runtime)"""
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""
    world_size: int = 0
    """the number of processes (computed in runtime)"""


def main():
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")

    args = tyro.cli(Args)
    args.world_size = world_size
    args.local_num_envs = args.num_envs // args.world_size
    args.local_batch_size = int(args.local_num_envs * args.num_steps)
    args.local_minibatch_size = int(args.minibatch_size // args.world_size)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.num_iterations = args.total_timesteps // args.batch_size
    args.env_threads = args.env_threads or args.num_envs
    args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
    args.collect_length = args.collect_length or args.num_steps
    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"

    local_torch_threads = args.torch_threads // args.world_size
    local_env_threads = args.env_threads // args.world_size

    torch.set_num_threads(local_torch_threads)
    torch.set_float32_matmul_precision('high')

    if args.world_size > 1:
        torchrun_setup(args.backend, local_rank)

    timestamp = int(time.time())
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
    writer = None
    if rank == 0:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

    ckpt_dir = os.path.join(args.ckpt_dir, run_name)
    os.makedirs(ckpt_dir, exist_ok=True)

    # TRY NOT TO MODIFY: seeding
    # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
    args.seed += rank
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed - rank)

    if args.torch_deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")

    deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    # env setup
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=args.local_num_envs,
        num_threads=local_env_threads,
        seed=args.seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode='self',
    )
    envs.num_envs = args.local_num_envs
    obs_space = envs.observation_space
    action_shape = envs.action_space.shape
    if local_rank == 0:
        fprint(f"obs_space={obs_space}, action_shape={action_shape}")

    envs_per_thread = args.local_num_envs // local_env_threads
    local_eval_episodes = args.eval_episodes // args.world_size
    local_eval_num_envs = local_eval_episodes
    eval_envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=local_eval_num_envs,
        num_threads=max(1, local_eval_num_envs // envs_per_thread),
        seed=args.seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode=args.play_mode,
    )
    eval_envs.num_envs = local_eval_num_envs

    envs = RecordEpisodeStatistics(envs)
    eval_envs = RecordEpisodeStatistics(eval_envs)

    if args.embedding_file:
        embeddings = load_embeddings(args.embedding_file, args.code_list_file)
        embedding_shape = embeddings.shape
    else:
        embedding_shape = None
    L = args.num_layers
    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
    agent.eval()

    if args.checkpoint:
        agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
        fprint(f"Loaded checkpoint from {args.checkpoint}")
    elif args.embedding_file:
        agent.load_embeddings(embeddings)
        fprint(f"Loaded embeddings from {args.embedding_file}")
    if args.embedding_file:
        agent.freeze_embeddings()

    agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
    agent_t.eval()
    agent_t.load_state_dict(agent.state_dict())

    optim_params = list(agent.parameters())
    optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)

    scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)

    def predict_step(agent: Agent, next_obs):
        with torch.no_grad():
            with autocast(enabled=args.fp16_eval):
                logits, value, valid = agent(next_obs)
        return logits, value

    from ygoai.rl.ppo import train_step

    if args.compile:
        # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
        # predict_step = torch.compile(predict_step, mode=args.compile)
        example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
        with torch.no_grad():
            traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
            traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
            traced_model_t = torch.jit.optimize_for_inference(traced_model_t)

        train_step = torch.compile(train_step, mode=args.compile)
    else:
        traced_model = agent
        traced_model_t = agent_t

    # ALGO Logic: Storage setup
    obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
    actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
    logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
    rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
    dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
    values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
    learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
    avg_ep_returns = deque(maxlen=1000)
    avg_win_rates = deque(maxlen=1000)
    version = 0

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    warmup_steps = 0
    start_time = time.time()
    next_obs, info = envs.reset()
    next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
    next_to_play_ = info["to_play"]
    next_to_play = to_tensor(next_to_play_, device)
    next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
    ai_player1_ = np.concatenate([
        np.zeros(args.local_num_envs // 2, dtype=np.int64),
        np.ones(args.local_num_envs // 2, dtype=np.int64)
    ])
    np.random.shuffle(ai_player1_)
    ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
    next_value1 = next_value2 = 0
    step = 0

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        model_time = 0
        env_time = 0
        collect_start = time.time()
        while step < args.collect_length:
            global_step += args.num_envs

            for key in obs:
                obs[key][step] = next_obs[key]
            dones[step] = next_done
            learn = next_to_play == ai_player1
            learns[step] = learn

            _start = time.time()
            logits, value = predict_step(traced_model, next_obs)
            logits_t, value_t = predict_step(traced_model_t, next_obs)
            logits = torch.where(learn[:, None], logits, logits_t)
            value = torch.where(learn[:, None], value, value_t)
            value = value.flatten()
            probs = Categorical(logits=logits)
            action = probs.sample()
            logprob = probs.log_prob(action)

            values[step] = value
            actions[step] = action
            logprobs[step] = logprob
            action = action.cpu().numpy()
            model_time += time.time() - _start

            next_nonterminal = 1 - next_done.float()
            next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
            next_value2 = torch.where(learn, next_value2, value) * next_nonterminal

            _start = time.time()
            to_play = next_to_play_
            next_obs, reward, next_done_, info = envs.step(action)
            next_to_play_ = info["to_play"]
            next_to_play = to_tensor(next_to_play_, device)
            env_time += time.time() - _start

            rewards[step] = to_tensor(reward, device)
            next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
            step += 1

            if not writer:
                continue

            for idx, d in enumerate(next_done_):
                if d:
                    pl = 1 if to_play[idx] == ai_player1_[idx] else -1
                    episode_length = info['l'][idx]
                    episode_reward = info['r'][idx] * pl
                    win = 1 if episode_reward > 0 else 0
                    avg_ep_returns.append(episode_reward)
                    avg_win_rates.append(win)

                    if random.random() < args.log_p:
                        n = 100
                        if random.random() < 10 / n or iteration <= 2:
                            writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                            writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
                            fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")

                        if random.random() < 1 / n:
                            writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
                            writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)

        collect_time = time.time() - collect_start
        if local_rank == 0:
            fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")

        step = args.collect_length - args.num_steps

        _start = time.time()
        # bootstrap value if not done
        with torch.no_grad():
            value = predict_step(traced_model, next_obs)[1].reshape(-1)
            value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
            value = torch.where(next_to_play == ai_player1, value, value_t)
            nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
            nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)

        if step > 0 and iteration != 1:
            # recalculate the values for the first few steps
            v_steps = args.local_minibatch_size * 4 // args.local_num_envs
            for v_start in range(0, step, v_steps):
                v_end = min(v_start + v_steps, step)
                v_obs = {
                    k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
                }
                with torch.no_grad():
                    # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
                    value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
                values[v_start:v_end] = value

        advantages = bootstrap_value_selfplay(
            values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
        bootstrap_time = time.time() - _start

        _start = time.time()
        # flatten the batch
        b_obs = {
            k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
            for k, v in obs.items()
        }
        b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
        b_logprobs = logprobs[:args.num_steps].reshape(-1)
        b_advantages = advantages[:args.num_steps].reshape(-1)
        b_values = values[:args.num_steps].reshape(-1)
        b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
        b_returns = b_advantages + b_values

        # Optimizing the policy and value network
        b_inds = np.arange(args.local_batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.local_batch_size, args.local_minibatch_size):
                end = start + args.local_minibatch_size
                mb_inds = b_inds[start:end]
                mb_obs = {
                    k: v[mb_inds] for k, v in b_obs.items()
                }
                old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
                    train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
                               b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
                reduce_gradidents(optim_params, args.world_size)
                nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                clipfracs.append(clipfrac.item())

            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        if step > 0:
            # TODO: use cyclic buffer to avoid copying
            for v in obs.values():
                v[:step] = v[args.num_steps:].clone()
            for v in [actions, logprobs, rewards, dones, values, learns]:
                v[:step] = v[args.num_steps:].clone()

        train_time = time.time() - _start

        if local_rank == 0:
            fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if rank == 0:
            if iteration % args.save_interval == 0:
                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))

            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
            writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
            writer.add_scalar("losses/explained_variance", explained_var, global_step)

        SPS = int((global_step - warmup_steps) / (time.time() - start_time))

        # Warmup at first few iterations for accurate SPS measurement
        SPS_warmup_iters = 10
        if iteration == SPS_warmup_iters:
            start_time = time.time()
            warmup_steps = global_step
        if iteration > SPS_warmup_iters:
            if local_rank == 0:
                fprint(f"SPS: {SPS}")
            if rank == 0:
                writer.add_scalar("charts/SPS", SPS, global_step)

        if rank == 0:
            should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
            should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
        else:
            should_update = torch.zeros((), dtype=torch.int64, device=device)
        if args.world_size > 1:
            dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
        should_update = should_update.item() > 0
        if should_update:
            agent_t.load_state_dict(agent.state_dict())
            with torch.no_grad():
                traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
                traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
            version += 1
            if rank == 0:
                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
                print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
            avg_win_rates.clear()
            avg_ep_returns.clear()

        _start = time.time()
        eval_return = evaluate(eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
        eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)

        # sync the statistics
        if args.world_size > 1:
            dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
        eval_return = eval_stats.cpu().numpy()
        if rank == 0:
            writer.add_scalar("charts/eval_return", eval_return, global_step)
        if local_rank == 0:
            eval_time = time.time() - _start
            fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")

        # Eval with old model

    if args.world_size > 1:
        dist.destroy_process_group()
    envs.close()
    if rank == 0:
        torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
        writer.close()


if __name__ == "__main__":
    main()
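A minimal sketch of the gating rule that the loop above uses to decide when the frozen opponent (agent_t) is replaced by the current learner. The helper name is hypothetical, not part of the repository; the default thresholds mirror the script's update_win_rate and update_return arguments.

import numpy as np
from collections import deque

def should_update_opponent(win_flags: deque, ep_returns: deque,
                           win_rate_threshold: float = 0.55,
                           return_threshold: float = 0.1,
                           window: int = 1000) -> bool:
    # Promote the learner to be the new frozen opponent only after a full
    # window of episodes in which it clearly beats the current opponent.
    return (len(win_flags) == window
            and np.mean(win_flags) > win_rate_threshold
            and np.mean(ep_returns) > return_threshold)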
scripts/ppo_t.py  +61 −29

...
@@ -102,8 +102,10 @@ class Args:
     """the maximum norm for the gradient clipping"""
     target_kl: Optional[float] = None
     """the target KL divergence threshold"""
-    learn_opponent: bool = True
+    learn_opponent: bool = False
     """if toggled, the samples from the opponent will be used to train the agent"""
+    collect_length: Optional[int] = None
+    """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""

     backend: Literal["gloo", "nccl", "mpi"] = "nccl"
     """the backend for distributed training"""
...
@@ -161,6 +163,9 @@ def main():
     args.num_iterations = args.total_timesteps // args.batch_size
     args.env_threads = args.env_threads or args.num_envs
     args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
+    args.collect_length = args.collect_length or args.num_steps
+    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"

     local_torch_threads = args.torch_threads // args.world_size
     local_env_threads = args.env_threads // args.world_size
...
@@ -248,7 +253,7 @@ def main():
     else:
         embedding_shape = None
     L = args.num_layers
-    agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
+    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
     agent.eval()

     if args.checkpoint:
...
@@ -260,22 +265,19 @@ def main():
     if args.embedding_file:
         agent.freeze_embeddings()

+    agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
+    agent_t.eval()
+    agent_t.load_state_dict(agent.state_dict())
+
     optim_params = list(agent.parameters())
     optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)

     scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)

-    agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
-    agent_t.eval()
-    agent_t.load_state_dict(agent.state_dict())
-
-    def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
+    def predict_step(agent: Agent, next_obs):
         with torch.no_grad():
             with autocast(enabled=args.fp16_eval):
                 logits, value, valid = agent(next_obs)
-                logits_t, value_t, valid = agent_t(next_obs)
-                logits = torch.where(learn[:, None], logits, logits_t)
-                value = torch.where(learn[:, None], value, value_t)
         return logits, value

     from ygoai.rl.ppo import train_step
...
@@ -289,15 +291,18 @@ def main():
             traced_model_t = torch.jit.optimize_for_inference(traced_model_t)

         train_step = torch.compile(train_step, mode=args.compile)
+    else:
+        traced_model = agent
+        traced_model_t = agent_t

     # ALGO Logic: Storage setup
-    obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
-    actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
-    logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
-    values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
+    obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
+    actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
+    logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
+    values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
+    learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
     version = 0
...
@@ -318,6 +323,7 @@ def main():
     np.random.shuffle(ai_player1_)
     ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
     next_value = 0
+    step = 0

     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
...
@@ -329,7 +335,7 @@ def main():
         model_time = 0
         env_time = 0
         collect_start = time.time()
-        for step in range(0, args.num_steps):
+        while step < args.collect_length:
             global_step += args.num_envs

             for key in obs:
...
@@ -339,7 +345,10 @@ def main():
             learns[step] = learn

             _start = time.time()
-            logits, value = predict_step(traced_model, traced_model_t, next_obs, learn)
+            logits, value = predict_step(traced_model, next_obs)
+            logits_t, value_t = predict_step(traced_model_t, next_obs)
+            logits = torch.where(learn[:, None], logits, logits_t)
+            value = torch.where(learn[:, None], value, value_t)
             value = value.flatten()
             probs = Categorical(logits=logits)
             action = probs.sample()
...
@@ -362,6 +371,7 @@ def main():
             env_time += time.time() - _start
             rewards[step] = to_tensor(reward, device)
             next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
+            step += 1

             if not writer:
                 continue
...
@@ -390,6 +400,8 @@ def main():
         if local_rank == 0:
             fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")

+        step = args.collect_length - args.num_steps
+
         _start = time.time()
         # bootstrap value if not done
         with torch.no_grad():
...
@@ -397,23 +409,36 @@ def main():
             value_t = traced_model_t(next_obs)[1].reshape(-1)
             value = torch.where(next_to_play == ai_player1, value, value_t)
             nextvalues = torch.where(next_to_play == ai_player1, value, next_value)

+        if step > 0 and iteration != 1:
+            # recalculate the values for the first few steps
+            v_steps = args.local_minibatch_size * 4 // args.local_num_envs
+            for v_start in range(0, step, v_steps):
+                v_end = min(v_start + v_steps, step)
+                v_obs = {
+                    k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
+                }
+                with torch.no_grad():
+                    # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
+                    value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
+                values[v_start:v_end] = value
+
         advantages = bootstrap_value_self(
             values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
-        returns = advantages + values
         bootstrap_time = time.time() - _start

         _start = time.time()
         # flatten the batch
         b_obs = {
-            k: v.reshape((-1,) + v.shape[2:])
+            k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
             for k, v in obs.items()
         }
-        b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + action_shape)
-        b_advantages = advantages.reshape(-1)
-        b_returns = returns.reshape(-1)
-        b_values = values.reshape(-1)
-        b_learns = learns.reshape(-1)
+        b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
+        b_logprobs = logprobs[:args.num_steps].reshape(-1)
+        b_advantages = advantages[:args.num_steps].reshape(-1)
+        b_values = values[:args.num_steps].reshape(-1)
+        b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
+        b_returns = b_advantages + b_values

         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
...
@@ -437,7 +462,14 @@ def main():
             if args.target_kl is not None and approx_kl > args.target_kl:
                 break

+        if step > 0:
+            # TODO: use cyclic buffer to avoid copying
+            for v in obs.values():
+                v[:step] = v[args.num_steps:].clone()
+            for v in [actions, logprobs, rewards, dones, values, learns]:
+                v[:step] = v[args.num_steps:].clone()
+
         train_time = time.time() - _start

         if local_rank == 0:
...
@@ -497,7 +529,7 @@ def main():
             _start = time.time()
             eval_return = evaluate(
-                eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)
+                eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
             eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)

             # sync the statistics
...
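The ppo_t.py refactor above moves the learner/target routing out of predict_step and into the rollout loop: each environment slot is assigned a side, and the frozen target's outputs are substituted wherever it is the target's turn to act. A minimal sketch of that routing, assuming logits and value share their leading batch dimension with the boolean learn mask:

import torch

def route_policy_outputs(learn, logits, value, logits_t, value_t):
    # learn[i] is True when env i is on the learner's turn; otherwise the
    # frozen target's logits/value are used for acting and bootstrapping.
    logits = torch.where(learn[:, None], logits, logits_t)
    value = torch.where(learn[:, None], value, value_t)
    return logits, value.flatten()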
ygoai/rl/agent.py  +40 −27

...
@@ -44,11 +44,9 @@ class PositionalEncoding(nn.Module):
 class Encoder(nn.Module):

     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
-                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+                 embedding_shape=None, bias=False, affine=True):
         super(Encoder, self).__init__()
         self.channels = channels
-        self.num_history_action_layers = num_history_action_layers

         c = channels
         self.loc_embed = nn.Embedding(9, c)
...
@@ -165,11 +163,17 @@ class Encoder(nn.Module):
             for i in range(num_action_layers)
         ])

-        self.action_history_pe = PositionalEncoding(c, dropout=0.0)
+        self.history_action_pe = PositionalEncoding(c, dropout=0.0)
+        self.history_action_net = nn.ModuleList([
+            nn.TransformerEncoderLayer(
+                c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
+            for i in range(num_action_layers)
+        ])
         self.action_history_net = nn.ModuleList([
             nn.TransformerDecoderLayer(
                 c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
-            for i in range(num_history_action_layers)
+            for i in range(num_action_layers)
         ])

         self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
...
@@ -287,6 +291,7 @@ class Encoder(nn.Module):
         x_cards = x['cards_']
         x_global = x['global_']
         x_actions = x['actions_']
+        batch_size = x_cards.shape[0]

         x_cards_1 = x_cards[:, :, :12].long()
         x_cards_2 = x_cards[:, :, 12:].to(torch.float32)
...
@@ -294,7 +299,10 @@ class Encoder(nn.Module):
         x_id = self.encode_card_id(x_cards_1[:, :, :2])
         x_id = self.id_norm(x_id)
-        f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 2]))
+        x_loc = x_cards_1[:, :, 2]
+        c_mask = x_loc == 0
+        c_mask[:, 0] = False
+        f_loc = self.loc_norm(self.loc_embed(x_loc))
         f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3]))

         x_feat1 = self.encode_card_feat1(x_cards_1)
...
@@ -306,11 +314,14 @@ class Encoder(nn.Module):
         f_cards = torch.cat([x_id, x_feat], dim=-1)
         f_cards = f_cards + f_loc + f_seq

-        f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
+        for layer in self.card_net:
+            # f_cards = layer(f_cards, src_key_padding_mask=c_mask)
+            f_cards = layer(f_cards, src_key_padding_mask=c_mask)
+
+        f_na_card = self.na_card_embed.expand(batch_size, -1, -1)
         f_cards = torch.cat([f_na_card, f_cards], dim=1)
+        # TODO: we can't use it because cudagraph says complex memory
+        # c_mask = torch.cat([torch.zeros(batch_size, 1, dtype=c_mask.dtype, device=c_mask.device), c_mask], dim=1)

-        for layer in self.card_net:
-            f_cards = layer(f_cards)
         f_cards = self.card_norm(f_cards)

         x_global = self.encode_global(x_global)
...
@@ -334,21 +345,24 @@ class Encoder(nn.Module):
         valid = x['global_'][:, -1] == 0
         mask[:, 0] &= valid

         for layer in self.action_card_net:
-            f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
+            f_actions = layer(f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)

-        if self.num_history_action_layers != 0:
-            x_h_actions = x['h_actions_']
-            x_h_actions = x_h_actions.long()
-            x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
-            x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
-            x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
-            f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
-            f_h_actions = self.action_history_pe(f_h_actions)
-            for layer in self.action_history_net:
-                f_actions = layer(f_actions, f_h_actions)
+        x_h_actions = x['h_actions_']
+        x_h_actions = x_h_actions.long()
+        x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
+        h_mask = x_h_actions[:, :, 2] == 0  # msg == 0
+        h_mask[:, 0] = False
+        x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
+        x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
+        f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
+        f_h_actions = self.history_action_pe(f_h_actions)
+        for layer in self.history_action_net:
+            f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
+        for layer in self.action_history_net:
+            f_actions = layer(f_actions, f_h_actions, tgt_key_padding_mask=mask, memory_key_padding_mask=h_mask)

         f_actions = self.action_norm(f_actions)
...
@@ -385,13 +399,12 @@ class Actor(nn.Module):
 class PPOAgent(nn.Module):

     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
-                 num_history_action_layers=2, embedding_shape=None, bias=False,
+                 embedding_shape=None, bias=False,
                  affine=True, a_trans=True):
         super(PPOAgent, self).__init__()

         self.encoder = Encoder(
-            channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
+            channels, num_card_layers, num_action_layers, embedding_shape, bias, affine)

         c = channels
         self.actor = Actor(c, a_trans)
...
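The encoder change above replaces the single cross-attention over history actions with a two-stage pathway: history actions are first self-encoded under a padding mask, and the current action features then cross-attend to the encoded history. A minimal standalone sketch of that pattern (layer counts and head count are illustrative; the repository's decoder layers also pass bias=False, which needs PyTorch 2.1+):

import torch
import torch.nn as nn

c, num_heads, num_layers = 128, 2, 2

history_action_net = nn.ModuleList([
    nn.TransformerEncoderLayer(c, num_heads, c * 4, dropout=0.0,
                               batch_first=True, norm_first=True)
    for _ in range(num_layers)
])
action_history_net = nn.ModuleList([
    nn.TransformerDecoderLayer(c, num_heads, c * 4, dropout=0.0,
                               batch_first=True, norm_first=True)
    for _ in range(num_layers)
])

def fuse_history(f_actions, f_h_actions, mask, h_mask):
    # Self-attention over history actions, ignoring padded entries.
    for layer in history_action_net:
        f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
    # Current actions cross-attend to the encoded history.
    for layer in action_history_net:
        f_actions = layer(f_actions, f_h_actions,
                          tgt_key_padding_mask=mask,
                          memory_key_padding_mask=h_mask)
    return f_actions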
ygoai/rl/ppo.py  +1 −2

...
@@ -11,8 +11,7 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
         probs = Categorical(logits=logits)
         newlogprob = probs.log_prob(mb_actions)
         entropy = probs.entropy()
-        if not args.learn_opponent:
-            valid = torch.logical_and(valid, mb_learns)
+        valid = torch.logical_and(valid, mb_learns)
         logratio = newlogprob - mb_logprobs
         ratio = logratio.exp()
...
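With this change, train_step always restricts its loss statistics to valid positions, which now also exclude the opponent's transitions unless the batch-level b_learns mask in the scripts re-enables them (it is set to all-True when learn_opponent is toggled). A tiny illustration of the masking pattern this relies on; the helper name is hypothetical, not the repository's API.

import torch

def masked_mean(x: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
    # Average x only over positions where valid is True.
    valid = valid.float()
    return (x * valid).sum() / valid.sum().clamp(min=1)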
ygoenv/ygoenv/ygopro/ygopro.h  +1 −1

...
@@ -1870,10 +1870,10 @@ private:
   std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
     SpecIndex spec2index;
     std::vector<int> loc_n_cards;
+    int offset = 0;
     for (auto pi = 0; pi < 2; pi++) {
       const PlayerId player = (to_play + pi) % 2;
       const bool opponent = pi == 1;
-      int offset = opponent ? spec_.config["max_cards"_] : 0;
       std::vector<std::pair<uint8_t, bool>> configs = {
         {LOCATION_DECK, true},  {LOCATION_HAND, true},
         {LOCATION_MZONE, false}, {LOCATION_SZONE, false},
...