Biluo Shen / ygo-agent / Commits

Commit 385bd1cb, authored Feb 27, 2024 by biluo.shen

Add PPO selfplay

parent af60d012

Showing 4 changed files with 296 additions and 141 deletions:

  scripts/eval.py                 +9    -6
  scripts/ppo_sp5.py              +211  -71
  ygoai/rl/agent.py               +61   -48
  ygoenv/ygoenv/ygopro/ygopro.h   +15   -16
scripts/eval.py  (+9 -6)

@@ -147,21 +147,21 @@ if __name__ == "__main__":
     embedding_shape = len(code_list)
     L = args.num_layers
     agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-    agent = agent.eval()
+    # agent = agent.eval()
     if args.checkpoint:
         state_dict = torch.load(args.checkpoint, map_location=device)
     else:
         state_dict = None
     if args.compile:
-        agent = torch.compile(agent, mode='reduce-overhead')
         if state_dict:
-            agent.load_state_dict(state_dict)
+            print(agent.load_state_dict(state_dict))
+        agent = torch.compile(agent, mode='reduce-overhead')
     else:
         prefix = "_orig_mod."
         if state_dict:
             state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-            agent.load_state_dict(state_dict)
+            print(agent.load_state_dict(state_dict))

     if args.optimize:
         obs = create_obs(envs.observation_space, (num_envs,), device=device)

@@ -170,6 +170,7 @@ if __name__ == "__main__":
         agent = torch.jit.optimize_for_inference(traced_model)

     obs, infos = envs.reset()
+    next_to_play = infos['to_play']

     episode_rewards = []
     episode_lengths = []

@@ -191,7 +192,7 @@ if __name__ == "__main__":
         _start = time.time()
         obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
         with torch.no_grad():
-            logits, values = agent(obs)
+            logits, values, _valid = agent(obs)
         probs = torch.softmax(logits, dim=-1)
         probs = probs.cpu().numpy()

         if args.play:

@@ -212,9 +213,11 @@ if __name__ == "__main__":
         # print(k, v.tolist())
         # print(infos)
         # print(actions[0])

+        to_play = next_to_play
         _start = time.time()
         obs, rewards, dones, infos = envs.step(actions)
+        next_to_play = infos['to_play']
         env_time += time.time() - _start

         step += 1

@@ -225,7 +228,7 @@ if __name__ == "__main__":
                 episode_length = infos['l'][idx]
                 episode_reward = infos['r'][idx]
                 if args.selfplay:
-                    pl = 1 if infos['to_play'][idx] == 0 else -1
+                    pl = 1 if to_play[idx] == 0 else -1
                     winner = 0 if episode_reward * pl > 0 else 1
                     win = 1 - winner
                 else:
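
Note on the eval.py change: `to_play` is now captured before `envs.step`, so the winner is derived from the player who actually made the final move rather than from the post-reset `infos['to_play']`. A minimal standalone restatement of that bookkeeping (a sketch; the helper name and the reward-sign convention are assumptions, not code from the repo):

    # Sketch: deriving the self-play result from the signed episode reward.
    # Assumption: episode_reward is from the perspective of the player recorded in
    # last_to_play at the final step, and a positive reward means that player won.
    def first_player_win(episode_reward: float, last_to_play: int) -> int:
        pl = 1 if last_to_play == 0 else -1            # +1 if player 0 moved last
        winner = 0 if episode_reward * pl > 0 else 1   # which seat won (0 or 1)
        return 1 - winner                              # 1 if the first player won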
scripts/ppo_sp.py → scripts/ppo_sp5.py  (+211 -71)

 import os
 import random
 import time
 from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional

@@ -52,7 +53,7 @@ class Args:
     """the maximum number of options"""
     n_history_actions: int = 16
     """the number of history actions to use"""
-    play_mode: str = "self+bot"
+    play_mode: str = "bot"
     """the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""

     num_layers: int = 2

@@ -74,6 +75,12 @@ class Args:
     """the discount factor gamma"""
     gae_lambda: float = 0.95
     """the lambda for the general advantage estimation"""
+    update_win_rate: float = 0.6
+    """the required win rate to update the agent"""
+    update_return: float = 0.1
+    """the required return to update the agent"""

     minibatch_size: int = 256
     """the mini-batch size"""
     update_epochs: int = 2

@@ -95,10 +102,8 @@ class Args:
     backend: Literal["gloo", "nccl", "mpi"] = "nccl"
     """the backend for distributed training"""
-    compile: bool = True
-    """whether to use torch.compile to compile the model and functions"""
-    compile_mode: Optional[str] = None
-    """the mode to use for torch.compile"""
+    compile: Optional[str] = None
+    """Compile mode of torch.compile, None for no compilation"""
     torch_threads: Optional[int] = None
     """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
     env_threads: Optional[int] = None

@@ -118,6 +123,8 @@ class Args:
     """the probability of logging"""
     port: int = 12356
     """the port to use for distributed training"""
+    eval_episodes: int = 128
+    """the number of episodes to evaluate the model"""

     # to be filled in runtime
     local_batch_size: int = 0
@@ -197,7 +204,7 @@ def run(local_rank, world_size):
         deck2=args.deck2,
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
-        play_mode=args.play_mode,
+        play_mode='self',
     )
     envs.num_envs = args.local_num_envs
     obs_space = envs.observation_space

@@ -205,7 +212,25 @@ def run(local_rank, world_size):
     if local_rank == 0:
         print(f"obs_space={obs_space}, action_shape={action_shape}")

+    envs_per_thread = args.local_num_envs // local_env_threads
+    local_eval_episodes = args.eval_episodes // args.world_size
+    local_eval_num_envs = local_eval_episodes
+    eval_envs = ygoenv.make(
+        task_id=args.env_id,
+        env_type="gymnasium",
+        num_envs=local_eval_num_envs,
+        num_threads=max(1, local_eval_num_envs // envs_per_thread),
+        seed=args.seed,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        max_options=args.max_options,
+        n_history_actions=args.n_history_actions,
+        play_mode=args.play_mode,
+    )
+    eval_envs.num_envs = local_eval_num_envs

     envs = RecordEpisodeStatistics(envs)
+    eval_envs = RecordEpisodeStatistics(eval_envs)

     if args.embedding_file:
         embeddings = np.load(args.embedding_file)
@@ -213,11 +238,14 @@ def run(local_rank, world_size):
     else:
         embedding_shape = None
     L = args.num_layers
-    agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+    agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
     if args.embedding_file:
-        agent.load_embeddings(embeddings)
+        agent1.load_embeddings(embeddings)
+    agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+    agent2.load_state_dict(agent1.state_dict())

-    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
+    optim_params = list(agent1.parameters())
+    optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
     scaler = GradScaler(enabled=args.fp16_train, init_scale=2**8)
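
The script now keeps two copies of the network: `agent1` is the one being optimized, while `agent2` is a frozen snapshot used as the self-play opponent and only refreshed later when the `should_update` check fires. A minimal sketch of that pattern, with a stand-in `nn.Linear` in place of the repo's `Agent` (an illustration, not the repo's code):

    import torch.nn as nn
    import torch.optim as optim

    agent1 = nn.Linear(8, 4)                  # learning agent: receives gradient updates
    agent2 = nn.Linear(8, 4)                  # frozen opponent: copied weights, never optimized
    agent2.load_state_dict(agent1.state_dict())

    optim_params = list(agent1.parameters())  # only agent1's parameters go to the optimizer
    optimizer = optim.Adam(optim_params, lr=2.5e-4, eps=1e-5)

    # Later, once the learner is strong enough, the opponent snapshot is refreshed:
    agent2.load_state_dict(agent1.state_dict())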
@@ -225,9 +253,21 @@ def run(local_rank, world_size):
         x = x.masked_fill(~valid, 0)
         return x.sum() / valid.float().sum()

-    def train_step(agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
+    def masked_normalize(x, valid, eps=1e-8):
+        x = x.masked_fill(~valid, 0)
+        n = valid.float().sum()
+        mean = x.sum() / n
+        var = ((x - mean) ** 2).sum() / n
+        std = (var + eps).sqrt()
+        return (x - mean) / std
+
+    def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
         with autocast(enabled=args.fp16_train):
-            _, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
+            logits, newvalue, valid = agent(mb_obs)
+            probs = Categorical(logits=logits)
+            newlogprob = probs.log_prob(mb_actions)
+            entropy = probs.entropy()
+        valid = torch.logical_and(valid, mb_learns)
         logratio = newlogprob - mb_logprobs
         ratio = logratio.exp()
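
`masked_normalize` standardizes advantages using only the entries flagged by `valid` (steps with a valid action that also belong to the learning agent), instead of over the whole minibatch. A small self-contained check of the behaviour (illustrative values; the function body mirrors the diff above):

    import torch

    def masked_normalize(x, valid, eps=1e-8):
        # As in train_step above: invalid entries are zeroed and the statistics
        # are divided by the number of valid entries.
        x = x.masked_fill(~valid, 0)
        n = valid.float().sum()
        mean = x.sum() / n
        var = ((x - mean) ** 2).sum() / n
        return (x - mean) / (var + eps).sqrt()

    adv = torch.tensor([1.0, -2.0, 3.0, 100.0])      # last entry is an invalid step
    valid = torch.tensor([True, True, True, False])
    print(masked_normalize(adv, valid))              # invalid entry no longer dominates the mean;
                                                     # its output value is masked out downstream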
@@ -238,7 +278,7 @@ def run(local_rank, world_size):
         clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()

         if args.norm_adv:
-            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
+            mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)

         # Policy loss
         pg_loss1 = -mb_advantages * ratio
@@ -269,15 +309,25 @@ def run(local_rank, world_size):
             scaler.unscale_(optimizer)
         return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss

-    def predict_step(agent, next_obs):
+    def predict_step(agent1: Agent, agent2: Agent, next_obs, learn):
         with torch.no_grad():
             with autocast(enabled=args.fp16_eval):
-                logits, values = agent(next_obs)
-        return logits, values
+                logits1, value1, valid = agent1(next_obs)
+                logits2, value2, valid = agent2(next_obs)
+        logits = torch.where(learn[:, None], logits1, logits2)
+        value = torch.where(learn[:, None], value1, value2)
+        return logits, value
+
+    def eval_step(agent: Agent, next_obs):
+        with torch.no_grad():
+            with autocast(enabled=args.fp16_eval):
+                logits = agent.get_logit(next_obs)
+        return logits

     if args.compile:
-        train_step = torch.compile(train_step, mode=args.compile_mode)
-        predict_step = torch.compile(predict_step, mode=args.compile_mode)
+        train_step = torch.compile(train_step, mode=args.compile)
+        predict_step = torch.compile(predict_step, mode='default')
+        # eval_step = torch.compile(eval_step, mode=args.compile)

     def to_tensor(x, dtype=torch.float32):
         return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
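
In the new `predict_step`, both networks evaluate the whole batch and `torch.where` picks, per environment, the output of `agent1` where it is the learner's turn (`learn` is True) and of `agent2` otherwise. A minimal standalone sketch of that batched dispatch with stand-in models (shapes and models are illustrative, not the repo's):

    import torch
    import torch.nn as nn

    agent1, agent2 = nn.Linear(8, 5), nn.Linear(8, 5)   # stand-ins for the two Agent copies
    obs = torch.randn(4, 8)                             # batch of 4 environments
    learn = torch.tensor([True, False, True, False])    # where the learner is to move

    with torch.no_grad():
        logits1, logits2 = agent1(obs), agent2(obs)     # both run on the full batch
    # Per-row selection: learner's logits where learn is True, opponent's otherwise.
    logits = torch.where(learn[:, None], logits1, logits2)

Both forward passes run on every row and the rows are selected afterwards, which keeps the step a single dense batched computation at the cost of one redundant forward per agent.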
@@ -287,12 +337,12 @@ def run(local_rank, world_size):
     actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
     logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
     rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
     values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    to_plays = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
-    avg_ep_returns = []
-    avg_win_rates = []
-    avg_sp_win_rates = []
+    learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
+    avg_ep_returns = deque(maxlen=1000)
+    avg_win_rates = deque(maxlen=1000)
+    version = 0

     # TRY NOT TO MODIFY: start the game
     global_step = 0
@@ -300,8 +350,16 @@ def run(local_rank, world_size):
     start_time = time.time()
     next_obs, info = envs.reset()
     next_obs = to_tensor(next_obs, dtype=torch.uint8)
-    next_to_play = to_tensor(info["to_play"])
-    next_done = torch.zeros(args.local_num_envs, device=device)
+    next_to_play_ = info["to_play"]
+    next_to_play = to_tensor(next_to_play_)
+    next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
+    ai_player_ = np.concatenate([
+        np.zeros(args.local_num_envs // 2, dtype=np.int64),
+        np.ones(args.local_num_envs // 2, dtype=np.int64)
+    ])
+    np.random.shuffle(ai_player_)
+    ai_player = to_tensor(ai_player_, dtype=next_to_play.dtype)
+    next_value = 0

     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
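
`ai_player_` fixes, per environment, which seat the learning agent occupies: half of the environments as player 0, half as player 1, shuffled once at startup. During the rollout, the mask `learn = next_to_play == ai_player` then marks the steps that belong to the learner. A small sketch of the same assignment (numbers are illustrative):

    import numpy as np

    local_num_envs = 8
    ai_player_ = np.concatenate([
        np.zeros(local_num_envs // 2, dtype=np.int64),   # learner sits as player 0 here
        np.ones(local_num_envs // 2, dtype=np.int64),    # ... and as player 1 here
    ])
    np.random.shuffle(ai_player_)

    to_play = np.array([0, 1, 0, 1, 0, 1, 0, 1])         # example "whose turn" vector from the env
    learn = to_play == ai_player_                        # True where the learning agent moves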
@@ -319,47 +377,44 @@ def run(local_rank, world_size):
             for key in obs:
                 obs[key][step] = next_obs[key]
             dones[step] = next_done
-            to_plays[step] = next_to_play
+            learn = next_to_play == ai_player
+            learns[step] = learn

             _start = time.time()
-            logits, value = predict_step(agent, next_obs)
+            logits, value = predict_step(agent1, agent2, next_obs, learn)
+            value = value.flatten()
             probs = Categorical(logits=logits)
             action = probs.sample()
             logprob = probs.log_prob(action)

-            values[step] = value.flatten()
+            values[step] = value
             actions[step] = action
             logprobs[step] = logprob
             action = action.cpu().numpy()
             model_time += time.time() - _start

+            next_value = torch.where(learn, value, next_value) * (1 - next_done.float())
             _start = time.time()
+            to_play = next_to_play_
             next_obs, reward, next_done_, info = envs.step(action)
-            next_to_play = to_tensor(info["to_play"])
+            next_to_play_ = info["to_play"]
+            next_to_play = to_tensor(next_to_play_)
             env_time += time.time() - _start

             rewards[step] = to_tensor(reward)
-            next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_)
-            collect_time = time.time() - collect_start
-            print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
+            next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool)

             if not writer:
                 continue

             for idx, d in enumerate(next_done_):
                 if d:
+                    pl = 1 if to_play[idx] == ai_player_[idx] else -1
                     episode_length = info['l'][idx]
-                    episode_reward = info['r'][idx]
+                    episode_reward = info['r'][idx] * pl
+                    win = 1 if episode_reward > 0 else 0
                     avg_ep_returns.append(episode_reward)
-                    if info['is_selfplay'][idx]:
-                        # win rate for the first player
-                        pl = 1 if next_to_play[idx] == 0 else -1
-                        winner = 0 if episode_reward * pl > 0 else 1
-                        avg_sp_win_rates.append(1 - winner)
-                    else:
-                        # win rate of agent
-                        winner = 0 if episode_reward > 0 else 1
-                        avg_win_rates.append(1 - winner)
+                    avg_win_rates.append(win)

                     if random.random() < args.log_p:
                         n = 100
@@ -368,37 +423,62 @@ def run(local_rank, world_size):
                         writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
                         print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")

-                        if len(avg_ep_returns) > n:
+                        if random.random() < 1 / n:
                             writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
-                            avg_ep_returns = []
-                        if len(avg_win_rates) > n:
                             writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
-                            avg_win_rates = []
-                        if len(avg_sp_win_rates) > n:
-                            writer.add_scalar("charts/avg_sp_win_rate", np.mean(avg_sp_win_rates), global_step)
-                            avg_sp_win_rates = []

+        collect_time = time.time() - collect_start
+        print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)

         # bootstrap value if not done
         with torch.no_grad():
-            next_value = agent.get_value(next_obs).reshape(1, -1)
+            value = agent1.get_value(next_obs).reshape(-1)
             advantages = torch.zeros_like(rewards).to(device)
+            nextvalues = torch.where(next_to_play == ai_player, value, next_value)
+            done_used = torch.zeros_like(next_done, dtype=torch.bool)
+            reward = 0
             lastgaelam = 0
-            next_to_play_ = next_to_play
             for t in reversed(range(args.num_steps)):
-                to_play = to_plays[t]
-                if t == args.num_steps - 1:
-                    nextnonterminal = 1.0 - next_done
-                    nextvalues = next_value
-                else:
-                    nextnonterminal = 1.0 - dones[t + 1]
-                    nextvalues = values[t + 1]
-                sp = 2.0 * (to_play == next_to_play_).float() - 1.0
-                delta = rewards[t] + args.gamma * nextvalues * sp * nextnonterminal - values[t]
-                lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
-                # TODO: experiment with it
-                # lastgaelam = lastgaelam * sp
-                advantages[t] = lastgaelam
-                next_to_play_ = to_play
+                # if learns[t]:
+                #     if dones[t+1]:
+                #         reward = rewards[t]
+                #         nextvalues = 0
+                #         lastgaelam = 0
+                #         done_used = True
+                #     else:
+                #         if not done_used:
+                #             reward = reward
+                #             nextvalues = 0
+                #             lastgaelam = 0
+                #             done_used = True
+                #         else:
+                #             reward = rewards[t]
+                #         delta = reward + args.gamma * nextvalues - values[t]
+                #         lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
+                #         advantages[t] = lastgaelam_
+                #         nextvalues = values[t]
+                #         lastgaelam = lastgaelam_
+                # else:
+                #     if dones[t+1]:
+                #         reward = -rewards[t]
+                #         done_used = False
+                #     else:
+                #         reward = reward
+                learn = learns[t]
+                if t != args.num_steps - 1:
+                    next_done = dones[t + 1]
+                sp = 2 * (learn.int() - 0.5)
+                reward = torch.where(next_done, rewards[t] * sp, torch.where(learn & done_used, 0, reward))
+                real_done = next_done | ~done_used
+                nextvalues = torch.where(real_done, 0, nextvalues)
+                lastgaelam = torch.where(real_done, 0, lastgaelam)
+                done_used = torch.where(next_done, learn, torch.where(learn & ~done_used, True, done_used))
+                delta = reward + args.gamma * nextvalues - values[t]
+                advantages[t] = lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
+                nextvalues = torch.where(learn, values[t], nextvalues)
+                lastgaelam = torch.where(learn, lastgaelam_, lastgaelam)
             returns = advantages + values

         _start = time.time()
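
The rewritten GAE pass operates on the whole batch with `torch.where`, but the underlying per-environment logic is the one sketched in the commented-out block above: advantages are propagated only along the learning agent's own steps, a terminal reward obtained on the opponent's move is carried backwards with its sign flipped until it reaches the learner's last step, and `done_used` remembers whether that terminal reward has already been consumed. The following single-environment paraphrase is an illustration of that idea, not a line-for-line equivalent of the vectorized code:

    def selfplay_gae_single_env(rewards, values, dones, learns,
                                bootstrap_value, final_done, gamma, gae_lambda):
        # rewards/values/dones/learns: per-step lists for ONE environment.
        # learns[t] is True when the learning agent (not the frozen opponent) acted at step t.
        T = len(rewards)
        advantages = [0.0] * T
        nextvalue = bootstrap_value   # value estimate for the state after the last stored step
        lastgaelam = 0.0
        reward = 0.0                  # pending terminal reward carried backwards
        done_used = False             # has the pending terminal reward been consumed yet?
        next_done = final_done
        for t in reversed(range(T)):
            learn = learns[t]
            if t != T - 1:
                next_done = dones[t + 1]
            if next_done:
                # Episode ended right after step t: take the terminal reward,
                # flipping its sign if the opponent was the one who moved.
                reward = rewards[t] if learn else -rewards[t]
            elif learn and done_used:
                reward = 0.0          # ordinary (non-terminal) learner step
            # otherwise: keep carrying the pending terminal reward backwards
            if next_done or not done_used:
                nextvalue, lastgaelam = 0.0, 0.0   # no bootstrapping across an episode end
            if next_done:
                done_used = learn
            elif learn and not done_used:
                done_used = True
            delta = reward + gamma * nextvalue - values[t]
            advantages[t] = lastgaelam_ = delta + gamma * gae_lambda * lastgaelam
            if learn:
                # Only the learner's own steps advance the recursion; opponent steps
                # leave nextvalue/lastgaelam untouched (their advantages are masked in training).
                nextvalue, lastgaelam = values[t], lastgaelam_
        return advantages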
@@ -412,6 +492,7 @@ def run(local_rank, world_size):
         b_advantages = advantages.reshape(-1)
         b_returns = returns.reshape(-1)
         b_values = values.reshape(-1)
+        b_learns = learns.reshape(-1)

         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
@@ -425,10 +506,10 @@ def run(local_rank, world_size):
                     k: v[mb_inds] for k, v in b_obs.items()
                 }
                 old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
-                    train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds],
-                               b_advantages[mb_inds], b_returns[mb_inds], b_values[mb_inds])
-                reduce_gradidents(agent, args.world_size)
-                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
+                    train_step(agent1, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds],
+                               b_advantages[mb_inds], b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
+                reduce_gradidents(optim_params, args.world_size)
+                nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
                 scaler.step(optimizer)
                 scaler.update()
                 clipfracs.append(clipfrac.item())
@@ -448,8 +529,8 @@ def run(local_rank, world_size):
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         if local_rank == 0:
-            if iteration % args.save_interval == 0 or iteration == args.num_iterations:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+            if iteration % args.save_interval == 0:
+                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))

             writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
             writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -471,10 +552,69 @@ def run(local_rank, world_size):
             print("SPS:", SPS)
             writer.add_scalar("charts/SPS", SPS, global_step)

+        if local_rank == 0:
+            should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
+            should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
+        else:
+            should_update = torch.zeros((), dtype=torch.int64, device=device)
+        if args.world_size > 1:
+            dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
+        should_update = should_update.item() > 0
+
+        if should_update:
+            agent2.load_state_dict(agent1.state_dict())
+            version += 1
+            if local_rank == 0:
+                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pth"))
+                print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
+                avg_win_rates.clear()
+                avg_ep_returns.clear()
+
+            _start = time.time()
+            episode_lengths = []
+            episode_rewards = []
+            eval_win_rates = []
+            e_obs = eval_envs.reset()[0]
+            while True:
+                e_obs = to_tensor(e_obs, dtype=torch.uint8)
+                e_logits = eval_step(agent1, e_obs)
+                e_probs = torch.softmax(e_logits, dim=-1)
+                e_probs = e_probs.cpu().numpy()
+                e_actions = e_probs.argmax(axis=1)
+
+                e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions)
+
+                for idx, d in enumerate(e_dones):
+                    if d:
+                        episode_length = e_info['l'][idx]
+                        episode_reward = e_info['r'][idx]
+                        win = 1 if episode_reward > 0 else 0
+                        episode_lengths.append(episode_length)
+                        episode_rewards.append(episode_reward)
+                        eval_win_rates.append(win)
+                if len(episode_lengths) >= local_eval_episodes:
+                    break
+
+            eval_return = np.mean(episode_rewards[:local_eval_episodes])
+            eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
+            eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
+            eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
+
+            # sync the statistics
+            dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
+            if local_rank == 0:
+                eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
+                writer.add_scalar("charts/eval_return", eval_return, global_step)
+                writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
+                writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
+                eval_time = time.time() - _start
+                print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return}, eval_ep_len={eval_ep_len}, eval_win_rate={eval_win_rate}")

     if args.world_size > 1:
         dist.destroy_process_group()
     envs.close()
     if local_rank == 0:
         torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
         writer.close()
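
The decision to refresh the frozen opponent is made on rank 0 (a full deque of 1000 recent games with win rate above `update_win_rate` and mean return above `update_return`) and then shared with the other ranks by summing an integer flag with `dist.all_reduce`, so every process swaps `agent2` in the same iteration. A minimal sketch of that agreement pattern (assuming an already-initialized process group; the helper name is hypothetical):

    import torch
    import torch.distributed as dist

    def agree_on_update(local_decision: bool, local_rank: int, world_size: int, device) -> bool:
        # Only rank 0's opinion matters; all other ranks contribute 0 to the sum.
        flag = torch.tensor(int(local_decision) if local_rank == 0 else 0,
                            dtype=torch.int64, device=device)
        if world_size > 1:
            dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # every rank now sees rank 0's flag
        return flag.item() > 0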
ygoai/rl/agent.py  (+61 -48)

@@ -320,66 +320,88 @@ class Encoder(nn.Module):
         return f_actions, f_state, mask, valid


-class PPOCritic(nn.Module):
-
-    def __init__(self, channels):
-        super(PPOCritic, self).__init__()
-        c = channels
-
-        self.net = nn.Sequential(
-            nn.Linear(c * 2, c // 2),
-            nn.ReLU(),
-            nn.Linear(c // 2, 1),
-        )
-
-    def forward(self, f_state):
-        return self.net(f_state)
-
-
-class PPOActor(nn.Module):
-
-    def __init__(self, channels):
-        super(PPOActor, self).__init__()
+# class PPOCritic(nn.Module):
+
+#     def __init__(self, channels):
+#         super(PPOCritic, self).__init__()
+#         c = channels
+
+#         self.net = nn.Sequential(
+#             nn.Linear(c * 2, c // 2),
+#             nn.ReLU(),
+#             nn.Linear(c // 2, 1),
+#         )
+
+#     def forward(self, f_state):
+#         return self.net(f_state)
+
+
+# class PPOActor(nn.Module):
+
+#     def __init__(self, channels):
+#         super(PPOActor, self).__init__()
+#         c = channels
+
+#         self.trans = nn.TransformerEncoderLayer(
+#             c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
+#         self.head = nn.Sequential(
+#             nn.Linear(c, c // 4),
+#             nn.ReLU(),
+#             nn.Linear(c // 4, 1),
+#         )
+
+#     def forward(self, f_actions, mask, action):
+#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
+#         logits = self.head(f_actions)[..., 0]
+#         logits = logits.float()
+#         logits = logits.masked_fill(mask, float("-inf"))
+#         probs = Categorical(logits=logits)
+#         return probs.log_prob(action), probs.entropy()
+
+#     def predict(self, f_actions, mask):
+#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
+#         logits = self.head(f_actions)[..., 0]
+#         logits = logits.float()
+#         logits = logits.masked_fill(mask, float("-inf"))
+#         return logits
+
+
+class Actor(nn.Module):
+
+    def __init__(self, channels, use_transformer=False):
+        super(Actor, self).__init__()
         c = channels
-        self.trans = nn.TransformerEncoderLayer(
-            c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
+        self.use_transformer = use_transformer
+        if use_transformer:
+            self.transformer = nn.TransformerEncoderLayer(
+                c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
         self.head = nn.Sequential(
             nn.Linear(c, c // 4),
            nn.ReLU(),
             nn.Linear(c // 4, 1),
         )

-    def forward(self, f_actions, mask, action):
-        f_actions = self.trans(f_actions, src_key_padding_mask=mask)
-        logits = self.head(f_actions)[..., 0]
-        logits = logits.float()
-        logits = logits.masked_fill(mask, float("-inf"))
-        probs = Categorical(logits=logits)
-        return probs.log_prob(action), probs.entropy()
-
-    def predict(self, f_actions, mask):
-        f_actions = self.trans(f_actions, src_key_padding_mask=mask)
+    def forward(self, f_actions, mask):
+        if self.use_transformer:
+            f_actions = self.transformer(f_actions, src_key_padding_mask=mask)
         logits = self.head(f_actions)[..., 0]
         logits = logits.float()
         logits = logits.masked_fill(mask, float("-inf"))
         return logits


 class PPOAgent(nn.Module):

     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
-                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
+                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True, a_trans=True):
         super(PPOAgent, self).__init__()

         self.encoder = Encoder(
             channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)

         c = channels
-        self.actor = nn.Sequential(
-            nn.Linear(c, c // 4),
-            nn.ReLU(),
-            nn.Linear(c // 4, 1),
-        )
+        self.actor = Actor(c, a_trans)

         self.critic = nn.Sequential(
             nn.Linear(c * 2, c // 2),

@@ -390,24 +412,15 @@ class PPOAgent(nn.Module):
     def load_embeddings(self, embeddings, freeze=True):
         self.encoder.load_embeddings(embeddings, freeze)

-    def get_value(self, x):
+    def get_logit(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
-        return self.critic(f_state)
+        return self.actor(f_actions, mask)

-    def get_action_and_value(self, x, action):
+    def get_value(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
-        logits = self.actor(f_actions)[..., 0]
-        logits = logits.float()
-        logits = logits.masked_fill(mask, float("-inf"))
-        probs = Categorical(logits=logits)
-        return action, probs.log_prob(action), probs.entropy(), self.critic(f_state), valid
+        return self.critic(f_state)

     def forward(self, x):
         f_actions, f_state, mask, valid = self.encoder(x)
-        logits = self.actor(f_actions)[..., 0]
-        logits = logits.float()
-        logits = logits.masked_fill(mask, float("-inf"))
-        return logits, self.critic(f_state)
+        logits = self.actor(f_actions, mask)
+        return logits, self.critic(f_state), valid
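
The new `Actor` folds the old `PPOActor`/`nn.Sequential` heads into one module: an optional `nn.TransformerEncoderLayer` over the per-action features, a small MLP head, and masking of invalid actions with `-inf` before the logits reach a `Categorical`. A condensed, standalone version of that module with a dummy input follows (the shapes and the masking pattern in the example are illustrative, not the repo's observation layout):

    import torch
    import torch.nn as nn

    class Actor(nn.Module):
        def __init__(self, channels, use_transformer=False):
            super().__init__()
            c = channels
            self.use_transformer = use_transformer
            if use_transformer:
                # Self-attention over the candidate-action features.
                self.transformer = nn.TransformerEncoderLayer(
                    c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
            self.head = nn.Sequential(nn.Linear(c, c // 4), nn.ReLU(), nn.Linear(c // 4, 1))

        def forward(self, f_actions, mask):
            if self.use_transformer:
                f_actions = self.transformer(f_actions, src_key_padding_mask=mask)
            logits = self.head(f_actions)[..., 0].float()
            return logits.masked_fill(mask, float("-inf"))  # masked actions get zero probability

    # Illustrative shapes: batch of 2, up to 6 candidate actions, 128 channels.
    actor = Actor(128, use_transformer=True)
    f_actions = torch.randn(2, 6, 128)
    mask = torch.zeros(2, 6, dtype=torch.bool)
    mask[:, 4:] = True                   # last two action slots are padding
    print(actor(f_actions, mask).shape)  # torch.Size([2, 6])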
ygoenv/ygoenv/ygopro/ygopro.h  (+15 -16)

@@ -2935,7 +2935,6 @@ private:
       return;
     }
     auto player = read_u8();
-    to_play_ = player;
     auto size = read_u8();
     std::vector<Card> cards;
     for (int i = 0; i < size; ++i) {

@@ -3315,7 +3314,6 @@ private:
      throw std::runtime_error("Retry");
    } else if (msg_ == MSG_SELECT_BATTLECMD) {
      auto player = read_u8();
-     to_play_ = player;
      auto activatable = read_cardlist_spec(true);
      auto attackable = read_cardlist_spec(true, true);
      bool to_m2 = read_u8();

@@ -3366,6 +3364,7 @@ private:
      }
      int n_activatables = activatable.size();
      int n_attackables = attackable.size();
+     to_play_ = player;
      callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) {
        if (idx < n_activatables) {
          OCG_SetResponsei(pduel_, idx << 16);

@@ -3382,7 +3381,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_UNSELECT_CARD) {
      auto player = read_u8();
-     to_play_ = player;
      bool finishable = read_u8();
      bool cancelable = read_u8();
      auto min = read_u8();

@@ -3435,6 +3433,7 @@ private:
      // cancelable and finishable not needed
+     to_play_ = player;
      callback_ = [this](int idx) {
        if (options_[idx] == "f") {
          OCG_SetResponsei(pduel_, -1);

@@ -3447,7 +3446,6 @@ private:
      }
    } else if (msg_ == MSG_SELECT_CARD) {
      auto player = read_u8();
-     to_play_ = player;
      bool cancelable = read_u8();
      auto min = read_u8();
      auto max = read_u8();

@@ -3535,6 +3533,7 @@ private:
        }
      }
+     to_play_ = player;
      callback_ = [this, combs](int idx) {
        const auto &comb = combs[idx];
        resp_buf_[0] = comb.size();

@@ -3545,7 +3544,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_TRIBUTE) {
      auto player = read_u8();
-     to_play_ = player;
      bool cancelable = read_u8();
      auto min = read_u8();
      auto max = read_u8();

@@ -3621,6 +3619,7 @@ private:
        options_.push_back(option);
      }
+     to_play_ = player;
      callback_ = [this, combs](int idx) {
        const auto &comb = combs[idx];
        resp_buf_[0] = comb.size();

@@ -3632,7 +3631,6 @@ private:
    } else if (msg_ == MSG_SELECT_SUM) {
      auto mode = read_u8();
      auto player = read_u8();
-     to_play_ = player;
      auto val = read_u32();
      auto min = read_u8();
      auto max = read_u8();

@@ -3761,6 +3759,7 @@ private:
        options_.push_back(option);
      }
+     to_play_ = player;
      callback_ = [this, combs, must_select_size](int idx) {
        const auto &comb = combs[idx];
        resp_buf_[0] = must_select_size + comb.size();

@@ -3775,7 +3774,6 @@ private:
      }
    } else if (msg_ == MSG_SELECT_CHAIN) {
      auto player = read_u8();
-     to_play_ = player;
      auto size = read_u8();
      auto spe_count = read_u8();
      bool forced = read_u8();

@@ -3872,6 +3870,7 @@ private:
      if (!forced) {
        options_.push_back("c");
      }
+     to_play_ = player;
      callback_ = [this, forced](int idx) {
        const auto &option = options_[idx];
        if ((option == "c") && (!forced)) {

@@ -3882,7 +3881,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_YESNO) {
      auto player = read_u8();
-     to_play_ = player;
      if (verbose_) {
        auto desc = read_u32();

@@ -3907,6 +3905,7 @@ private:
        dp_ += 4;
      }
      options_ = {"y", "n"};
+     to_play_ = player;
      callback_ = [this](int idx) {
        if (idx == 0) {
          OCG_SetResponsei(pduel_, 1);

@@ -3918,7 +3917,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_EFFECTYN) {
      auto player = read_u8();
-     to_play_ = player;
      std::string spec;
      if (verbose_) {

@@ -3981,6 +3979,7 @@ private:
        spec = ls_to_spec(loc, seq, pos, c != player);
      }
      options_ = {"y " + spec, "n " + spec};
+     to_play_ = player;
      callback_ = [this](int idx) {
        if (idx == 0) {
          OCG_SetResponsei(pduel_, 1);

@@ -3992,7 +3991,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_OPTION) {
      auto player = read_u8();
-     to_play_ = player;
      auto size = read_u8();
      if (verbose_) {
        auto pl = players_[player];

@@ -4016,6 +4014,7 @@ private:
          options_.push_back(std::to_string(i + 1));
        }
      }
+     to_play_ = player;
      callback_ = [this](int idx) {
        if (verbose_) {
          players_[to_play_]->notify("You selected option " + options_[idx] +

@@ -4029,7 +4028,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_IDLECMD) {
      int32_t player = read_u8();
-     to_play_ = player;
      auto summonable_ = read_cardlist_spec();
      auto spsummon_ = read_cardlist_spec();
      auto repos_ = read_cardlist_spec();

@@ -4134,6 +4132,7 @@ private:
        }
      }
+     to_play_ = player;
      callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset, activate_offset](int idx) {
        const auto &option = options_[idx];

@@ -4169,7 +4168,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_PLACE) {
      auto player = read_u8();
-     to_play_ = player;
      auto count = read_u8();
      if (count == 0) {
        count = 1;

@@ -4189,6 +4187,7 @@ private:
          " places for card, from " + specs_str + ".");
        }
      }
+     to_play_ = player;
      callback_ = [this, player](int idx) {
        int y = player + 1;
        std::string spec = options_[idx];

@@ -4205,7 +4204,6 @@ private:
      };
    } else if (msg_ == MSG_SELECT_DISFIELD) {
      auto player = read_u8();
-     to_play_ = player;
      auto count = read_u8();
      if (count == 0) {
        count = 1;

@@ -4225,6 +4223,7 @@ private:
          std::to_string(count) + " not implemented");
        }
      }
+     to_play_ = player;
      callback_ = [this, player](int idx) {
        int y = player + 1;
        std::string spec = options_[idx];

@@ -4241,7 +4240,6 @@ private:
      };
    } else if (msg_ == MSG_ANNOUNCE_NUMBER) {
      auto player = read_u8();
-     to_play_ = player;
      auto count = read_u8();
      std::vector<int> numbers;
      for (int i = 0; i < count; ++i) {

@@ -4265,12 +4263,12 @@ private:
        str += "]";
        pl->notify(str);
      }
+     to_play_ = player;
      callback_ = [this](int idx) {
        OCG_SetResponsei(pduel_, idx);
      };
    } else if (msg_ == MSG_ANNOUNCE_ATTRIB) {
      auto player = read_u8();
-     to_play_ = player;
      auto count = read_u8();
      auto flag = read_u32();

@@ -4310,6 +4308,7 @@ private:
        options_.push_back(option);
      }
+     to_play_ = player;
      callback_ = [this](int idx) {
        const auto &option = options_[idx];
        uint32_t resp = 0;

@@ -4323,7 +4322,6 @@ private:
      }
    } else if (msg_ == MSG_SELECT_POSITION) {
      auto player = read_u8();
-     to_play_ = player;
      auto code = read_u32();
      auto valid_pos = read_u8();

@@ -4348,6 +4346,7 @@ private:
        i++;
      }
+     to_play_ = player;
      callback_ = [this](int idx) {
        uint8_t pos = options_[idx][0] - '1';
        OCG_SetResponsei(pduel_, 1 << pos);