Biluo Shen / ygo-agent / Commits / 389fcd9c

Commit 389fcd9c authored Mar 11, 2024 by sbl1996@126.com

fixing set_obs_cards seg fault

parent 46b4b5ae

Showing 13 changed files with 122 additions and 83 deletions (+122 -83)
Makefile                         +1   -1
README.md                        +1   -1
assets/deck/CenturIon2.ydk       +2   -1
assets/deck/Floowandereeze.ydk   +1   -1
assets/deck/Floowandereeze2.ydk  +1   -1
assets/deck/Labrynth.ydk         +1   -1
assets/deck/Pachycephalo.ydk     +1   -1
assets/deck/SnakeEyeAlter.ydk    +2   -1
scripts/code_list.txt            +1   -0
scripts/ppo2.py                  +18  -55
ygoai/rl/agent.py                +27  -1
ygoai/rl/ppo.py                  +0   -1
ygoenv/ygoenv/ygopro/ygopro.h    +66  -18
Makefile

-SCRIPTS_REPO := "https://github.com/Fluorohydride/ygopro-scripts.git"
+SCRIPTS_REPO := "https://github.com/mycard/ygopro-scripts.git"
 SCRIPTS_DIR := "../ygopro-scripts"
 DATABASE_REPO := "https://github.com/mycard/ygopro-database/raw/master/locales"
 LOCALES := en zh
...
README.md

@@ -11,7 +11,7 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).

 ## Building

-### prerequisites
+### Prerequisites
 - gcc 10+ or clang 11+
 - [xmake](https://xmake.io/#/getting_started)
 - PyTorch 2.0 or later with cuda support
...
assets/deck/CenturIon2.ydk

@@ -32,7 +32,7 @@
 35059553
 24224830
 24224830
-84211599
+97268402
 73628505
 40155014
 40155014
@@ -58,3 +58,4 @@
 02857636
 !side
 27204312
+92907249
assets/deck/Floowandereeze.ydk

@@ -28,7 +28,7 @@
 51697825
 51697825
 51697825
-84211599
+98645731
 28126717
 55521751
 24224830
...
assets/deck/Floowandereeze2.ydk

@@ -19,7 +19,7 @@
 80433039
 17827173
 24508238
-84211599
+98645731
 75500286
 98645731
 49238328
...
assets/deck/Labrynth.ydk

@@ -23,7 +23,7 @@
 97268402
 2511
 2511
-84211599
+97268402
 24224830
 24224830
 33407125
...
assets/deck/Pachycephalo.ydk

@@ -24,7 +24,7 @@
 1984618
 1984618
 35261759
-84211599
+49238328
 49238328
 84797028
 84797028
...
assets/deck/SnakeEyeAlter.ydk

@@ -24,7 +24,7 @@
 24508238
 2295440
 51405049
-84211599
+27204311
 89023486
 24224830
 24224830
@@ -58,3 +58,4 @@
 94259633
 !side
 52340445
+27204312
scripts/code_list.txt

@@ -840,3 +840,4 @@
 35809262
 92731385
 74018812
+92907249
scripts/ppo_t2.py → scripts/ppo2.py

@@ -5,7 +5,6 @@ from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional

 import ygoenv
 import numpy as np
 import tyro
...
@@ -22,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
 from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
-from ygoai.rl.ppo import bootstrap_value_self
+from ygoai.rl.ppo import bootstrap_value_selfplay
 from ygoai.rl.eval import evaluate
...
@@ -79,11 +78,6 @@ class Args:
     gae_lambda: float = 0.95
     """the lambda for the general advantage estimation"""
-    update_win_rate: float = 0.55
-    """the required win rate to update the agent"""
-    update_return: float = 0.1
-    """the required return to update the agent"""
     minibatch_size: int = 256
     """the mini-batch size"""
     update_epochs: int = 2
...
@@ -265,28 +259,21 @@ def main():
     scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
-    agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
-    agent_t.eval()
-    agent_t.load_state_dict(agent.state_dict())

-    def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
+    def predict_step(agent: Agent, next_obs):
         with torch.no_grad():
             with autocast(enabled=args.fp16_eval):
                 logits, value, valid = agent(next_obs)
-                logits_t, value_t, valid = agent_t(next_obs)
-        logits = torch.where(learn[:, None], logits, logits_t)
-        value = torch.where(learn[:, None], value, value_t)
         return logits, value

     from ygoai.rl.ppo import train_step

     if args.compile:
         # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
         # predict_step = torch.compile(predict_step, mode=args.compile)
-        agent = torch.compile(agent, mode=args.compile)
-        example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
+        obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
         with torch.no_grad():
-            traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
-            traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
-        train_step = torch.compile(train_step, mode=args.compile)
+            traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)

     # ALGO Logic: Storage setup
     obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
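The comment in the hunk above explains the workaround: calling torch.compile on two models caused a segfault at startup, so the rollout model is frozen with torch.jit.trace instead. Below is a minimal sketch of that pattern, assuming a made-up Policy module and arbitrary shapes (this is not the repo's Agent or its real observation format):

    import torch
    import torch.nn as nn

    # Hypothetical toy policy standing in for the repo's Agent; shapes are illustrative only.
    class Policy(nn.Module):
        def __init__(self, obs_dim=8, n_actions=4):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())
            self.logits = nn.Linear(64, n_actions)
            self.value = nn.Linear(64, 1)

        def forward(self, obs):
            h = self.body(obs)
            return self.logits(h), self.value(h)

    agent = Policy().eval()              # inference copy; a training copy would stay in train mode
    example_obs = torch.zeros(16, 8)     # example batch used only to record the trace

    with torch.no_grad():
        traced = torch.jit.trace(agent, (example_obs,), check_trace=False)
        traced = torch.jit.optimize_for_inference(traced)

    # Rollout-time inference reuses the traced module; gradients are not needed here.
    with torch.no_grad():
        logits, value = traced(torch.randn(16, 8))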
...
@@ -298,7 +285,6 @@ def main():
     learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
-    version = 0

     # TRY NOT TO MODIFY: start the game
     global_step = 0
...
@@ -315,7 +301,7 @@ def main():
     ])
     np.random.shuffle(ai_player1_)
     ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
-    next_value = 0
+    next_value1 = next_value2 = 0
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
...
@@ -324,8 +310,6 @@ def main():
             lrnow = frac * args.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow
-        agent.eval()

         model_time = 0
         env_time = 0
         collect_start = time.time()
...
@@ -339,7 +323,7 @@ def main():
             learns[step] = learn

             _start = time.time()
-            logits, value = predict_step(agent, traced_model_t, next_obs, learn)
+            logits, value = predict_step(traced_model, next_obs)
             value = value.flatten()
             probs = Categorical(logits=logits)
             action = probs.sample()
...
@@ -352,7 +336,8 @@ def main():
             model_time += time.time() - _start

             next_nonterminal = 1 - next_done.float()
-            next_value = torch.where(learn, value, next_value) * next_nonterminal
+            next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
+            next_value2 = torch.where(learn, next_value2, value) * next_nonterminal

             _start = time.time()
             to_play = next_to_play_
...
@@ -393,16 +378,14 @@ def main():
         _start = time.time()
         # bootstrap value if not done
         with torch.no_grad():
-            value = agent(next_obs)[1].reshape(-1)
-            value_t = traced_model_t(next_obs)[1].reshape(-1)
-            value = torch.where(next_to_play == ai_player1, value, value_t)
-            nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
-        advantages = bootstrap_value_self(
-            values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
+            value = traced_model(next_obs)[1].reshape(-1)
+            nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
+            nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
+        advantages = bootstrap_value_selfplay(
+            values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
         returns = advantages + values
         bootstrap_time = time.time() - _start

-        agent.train()
         _start = time.time()
         # flatten the batch
         b_obs = {
...
@@ -475,31 +458,11 @@ def main():
         if rank == 0:
             writer.add_scalar("charts/SPS", SPS, global_step)

-        if rank == 0:
-            should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
-            should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
-        else:
-            should_update = torch.zeros((), dtype=torch.int64, device=device)
-        if args.world_size > 1:
-            dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
-        should_update = should_update.item() > 0
-        if should_update:
-            agent_t.load_state_dict(agent.state_dict())
-            with torch.no_grad():
-                traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
-                traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
-            version += 1
-            if rank == 0:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
-                print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
-                avg_win_rates.clear()
-                avg_ep_returns.clear()
+        if iteration % args.eval_interval == 0:
+            # Eval with rule-based policy
             _start = time.time()
-            agent.eval()
             eval_return = evaluate(
-                eval_envs, agent, local_eval_episodes, device, args.fp16_eval)
+                eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
             eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)

             # sync the statistics
...
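For context on the next_value1 / next_value2 bookkeeping introduced above: in self-play the two players act on alternating steps, so the script keeps one bootstrap value per player and refreshes only the acting player's slot with torch.where over the learn mask. A minimal, purely illustrative sketch of that update follows (shapes and the random placeholders are invented; the actual GAE recursion lives in bootstrap_value_selfplay):

    import torch

    num_envs = 4
    next_value1 = torch.zeros(num_envs)  # last value estimate for player 1 in each env
    next_value2 = torch.zeros(num_envs)  # last value estimate for player 2 in each env

    for step in range(8):
        # learn[i] is True when the learning player (player 1) acts in env i this step
        learn = torch.rand(num_envs) < 0.5
        value = torch.randn(num_envs)      # critic output for the player about to act
        next_done = torch.zeros(num_envs)  # 1.0 where the episode just ended
        next_nonterminal = 1 - next_done

        # Only the acting player's slot is refreshed; the other keeps its previous estimate.
        next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
        next_value2 = torch.where(learn, next_value2, value) * next_nonterminal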
ygoai/rl/agent.py

+import math
 import torch
 import torch.nn as nn
...
@@ -17,6 +19,29 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
     intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0)
     return points, intervals

+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(max_len, 1, d_model)
+        pe[:, 0, 0::2] = torch.sin(position * div_term)
+        pe[:, 0, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        """
+        Arguments:
+            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
+        """
+        x = x + self.pe[:x.size(0)]
+        return self.dropout(x)
+
 class Encoder(nn.Module):

     def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
...
@@ -122,7 +147,6 @@ class Encoder(nn.Module):
             nn.Linear(c, c),
         )
         self.h_id_fc_emb = linear(1024, c)
         self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
         self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
...
@@ -134,6 +158,7 @@ class Encoder(nn.Module):
             for i in range(num_action_layers)
         ])
+        self.action_history_pe = PositionalEncoding(c, dropout=0.0)
         self.action_history_net = nn.ModuleList([
             nn.TransformerDecoderLayer(
                 c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
...
@@ -322,6 +347,7 @@ class Encoder(nn.Module):
         x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
         x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
         f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
+        f_h_actions = self.action_history_pe(f_h_actions)
         for layer in self.action_history_net:
             f_actions = layer(f_actions, f_h_actions)
...
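A self-contained illustration of the PositionalEncoding module added above (the demo shapes are arbitrary). The buffer is stored as [max_len, 1, d_model], so the encoding varies along the first axis of the input; in the Encoder it is applied to f_h_actions just before the action-history decoder layers:

    import math
    import torch
    import torch.nn as nn

    # Same sinusoidal encoding as in ygoai/rl/agent.py, reproduced here for a standalone demo.
    class PositionalEncoding(nn.Module):
        def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, 1, d_model)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe)

        def forward(self, x):
            # x: [seq_len, batch_size, d_model]
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

    pe = PositionalEncoding(d_model=128)
    h_actions = torch.randn(32, 8, 128)  # 32 history steps, batch of 8, 128 channels
    out = pe(h_actions)                  # same shape, with a per-step sinusoidal offset added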
ygoai/rl/ppo.py

@@ -108,7 +108,6 @@ def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done,

 def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
-    # TODO: drop epsilon steps for estimated nextvalues
     num_steps = rewards.size(0)
     advantages = torch.zeros_like(rewards)
     # TODO: optimize this
...
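The body of bootstrap_value_selfplay is outside this hunk. For readers who want the underlying idea, the following is the textbook single-stream GAE bootstrap that functions like bootstrap_value_self build on; it is a generic sketch, not the repo's implementation, and it ignores the selfplay-specific learns handling:

    import torch

    def gae_advantages(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
        # values, rewards, dones: [num_steps, num_envs]; nextvalues, next_done: [num_envs]
        num_steps = rewards.size(0)
        advantages = torch.zeros_like(rewards)
        lastgaelam = torch.zeros_like(nextvalues)
        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                nextnonterminal = 1.0 - next_done.float()
                next_values_t = nextvalues
            else:
                nextnonterminal = 1.0 - dones[t + 1].float()
                next_values_t = values[t + 1]
            delta = rewards[t] + gamma * next_values_t * nextnonterminal - values[t]
            lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
            advantages[t] = lastgaelam
        return advantages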
ygoenv/ygoenv/ygopro/ygopro.h

@@ -365,8 +365,7 @@ inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos) {
   return spec;
 }

-inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos,
-                              bool opponent) {
+inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos, bool opponent) {
   std::string spec = ls_to_spec(loc, seq, pos);
   if (opponent) {
     spec.insert(0, 1, 'o');
...
@@ -1471,6 +1470,7 @@ public:
   int init_lp = 8000;
   int startcount = 5;
   int drawcount = 1;
   for (PlayerId i = 0; i < 2; i++) {
     if (players_[i] != nullptr) {
       delete players_[i];
...
@@ -1641,10 +1641,12 @@ public:
   void Step(const Action &action) override {
     // clock_t start = clock();
+    // fmt::println("Step");
     int idx = action["action"_];
     callback_(idx);
-    update_history_actions(to_play_, idx);
+    // update_history_actions(to_play_, idx);
+    // fmt::println("update_history_actions");

     PlayerId player = to_play_;
...
@@ -2153,6 +2155,10 @@ private:
       ReplayWriteInt8(1);
       fwrite(buf, 1, 1, fp_);
       break;
+    case MSG_SELECT_COUNTER:
+      ReplayWriteInt8(2);
+      fwrite(buf, 2, 1, fp_);
+      break;
     case MSG_SELECT_PLACE:
     case MSG_SELECT_DISFIELD:
       ReplayWriteInt8(3);
...
@@ -2183,10 +2189,13 @@ private:
       state["obs:global_"_][22] = uint8_t(1);
       return;
     }
+    // fmt::println("writestate");
     auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
-    _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
+    // fmt::println("_set_obs_cards");
+    // _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
+    // fmt::println("_set_obs_global");
     // we can't shuffle because idx must be stable in callback
     if (n_options > max_options()) {
...
@@ -2198,10 +2207,12 @@ private:
     // fmt::println("{} {}", key, val);
     // }
-    _set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
+    // _set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
+    // fmt::println("_set_obs_actions");
     n_options = options_.size();
     state["info:num_options"_] = n_options;
+    return;

     // update h_card_ids from state
     auto &h_card_ids = to_play_ == 0 ? h_card_ids_0_ : h_card_ids_1_;
...
@@ -2222,6 +2233,7 @@ private:
       }
       h_card_ids[i] = card_ids;
     }
+    // fmt::println("update h_card_ids");

     // write history actions
...
@@ -2235,6 +2247,7 @@ private:
                                           n_action_feats * n1);
     state["obs:h_actions_"_][n1].Assign((uint8_t *)history_actions.Data(),
                                         n_action_feats * ha_p);
+    // fmt::println("write history actions");
   }

   void show_decision(int idx) {
...
@@ -2307,8 +2320,8 @@ private:
     if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) {
       if (options_.size() == 1) {
         callback_(0);
-        update_h_card_ids(to_play_, 0);
-        update_history_actions(to_play_, 0);
+        // update_h_card_ids(to_play_, 0);
+        // update_history_actions(to_play_, 0);
         if (verbose_) {
           show_decision(0);
         }
...
@@ -2488,6 +2501,7 @@ private:
       }
       cards.push_back(c);
     }
+    fmt::println("qdp: {}, bl: {}, n: {}", qdp_, bl, cards.size());
    return cards;
  }
...
@@ -2513,8 +2527,7 @@ private:
    return cards;
  }

-  std::vector<IdleCardSpec> read_cardlist_spec(bool extra = false,
-                                               bool extra8 = false) {
+  std::vector<IdleCardSpec> read_cardlist_spec(PlayerId player, bool extra = false, bool extra8 = false) {
    std::vector<IdleCardSpec> card_specs;
    auto count = read_u8();
    card_specs.reserve(count);
...
@@ -2531,7 +2544,7 @@ private:
          data = read_u32();
        }
      }
-      card_specs.push_back({code, ls_to_spec(loc, seq, 0), data});
+      card_specs.push_back({code, ls_to_spec(loc, seq, 0, player != controller), data});
    }
    return card_specs;
  }
...
@@ -2988,6 +3001,7 @@ private:
      if (verbose_) {
        cards.push_back(get_card(c, loc, seq));
      }
+      // TODO: check if this is correct
      revealed_.push_back(ls_to_spec(loc, seq, 0, c == player));
    }
    if (!verbose_) {
...
@@ -3406,8 +3420,8 @@ private:
      throw std::runtime_error("Retry");
    } else if (msg_ == MSG_SELECT_BATTLECMD) {
      auto player = read_u8();
-      auto activatable = read_cardlist_spec(true);
-      auto attackable = read_cardlist_spec(true, true);
+      auto activatable = read_cardlist_spec(player, true);
+      auto attackable = read_cardlist_spec(player, true, true);
      bool to_m2 = read_u8();
      bool to_ep = read_u8();
...
@@ -4122,12 +4136,12 @@ private:
      };
    } else if (msg_ == MSG_SELECT_IDLECMD) {
      int32_t player = read_u8();
-      auto summonable_ = read_cardlist_spec();
-      auto spsummon_ = read_cardlist_spec();
-      auto repos_ = read_cardlist_spec();
-      auto idle_mset_ = read_cardlist_spec();
-      auto idle_set_ = read_cardlist_spec();
-      auto idle_activate_ = read_cardlist_spec(true);
+      auto summonable_ = read_cardlist_spec(player);
+      auto spsummon_ = read_cardlist_spec(player);
+      auto repos_ = read_cardlist_spec(player);
+      auto idle_mset_ = read_cardlist_spec(player);
+      auto idle_set_ = read_cardlist_spec(player);
+      auto idle_activate_ = read_cardlist_spec(player, true);
      bool to_bp_ = read_u8();
      bool to_ep_ = read_u8();
      read_u8(); // can_shuffle
...
@@ -4332,6 +4346,35 @@ private:
        resp_buf_[2] = seq;
        YGO_SetResponseb(pduel_, resp_buf_);
      };
+    } else if (msg_ == MSG_SELECT_COUNTER) {
+      auto player = read_u8();
+      auto counter_type = read_u16();
+      auto counter_count = read_u16();
+      int count = read_u8();
+      if (count != 1) {
+        throw std::runtime_error("Select counter count " + std::to_string(count) + " not implemented");
+      }
+      auto pl = players_[player];
+      if (verbose_) {
+        pl->notify(fmt::format("Type new {} for {} card(s), separated by spaces.", "UNKNOWN_COUNTER", count));
+      }
+      for (int i = 0; i < count; ++i) {
+        auto code = read_u32();
+        auto controller = read_u8();
+        auto loc = read_u8();
+        auto seq = read_u8();
+        auto counter = read_u16();
+        if (verbose_) {
+          pl->notify(c_get_card(code).name_ + ": " + std::to_string(counter));
+        }
+        // auto spec = ls_to_spec(loc, seq, 0, controller != player);
+        // options_.push_back(spec);
+      }
+      // TODO: implement action
+      uint16_t resp = counter_count & 0xffff;
+      memcpy(resp_buf_, &resp, 2);
+      YGO_SetResponseb(pduel_, resp_buf_);
    } else if (msg_ == MSG_ANNOUNCE_NUMBER) {
      auto player = read_u8();
      int count = read_u8();
...
@@ -4448,6 +4491,11 @@ private:
    } else {
      show_deck(0);
      show_deck(1);
+      // print byte by byte
+      for (int i = 0; i < dp_; ++i) {
+        fmt::print("{:02x} ", data_[i]);
+      }
+      fmt::print("\n");
      throw std::runtime_error(
          fmt::format("Unknown message {}, length {}, dp {}",
                      msg_to_string(msg_), dl_, dp_));
...
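To make the read_cardlist_spec / ls_to_spec change concrete: the new PlayerId argument lets the reader mark cards whose controller differs from the observing player, and the ls_to_spec(loc, seq, pos, opponent) overload prepends an 'o' to the location spec in that case. A small Python mock of that convention (the location names and codes here are illustrative, not the exact C++ encoding):

    # Hypothetical mirror of the C++ ls_to_spec(loc, seq, pos, opponent) overload,
    # using made-up location names just to show the 'o' prefix rule.
    LOCATION_NAMES = {0x02: "h", 0x04: "m", 0x08: "s"}  # hand, monster zone, spell/trap zone (assumed)

    def ls_to_spec(loc: int, seq: int, pos: int, opponent: bool = False) -> str:
        spec = f"{LOCATION_NAMES.get(loc, '?')}{seq + 1}"
        if opponent:
            spec = "o" + spec  # cards controlled by the opponent get an 'o' prefix
        return spec

    # A card read from MSG_SELECT_IDLECMD: controller != player means it is the opponent's card.
    player, controller = 0, 1
    print(ls_to_spec(0x04, 2, 0, opponent=(player != controller)))  # -> "om3"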