Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
04e61b91
Commit
04e61b91
authored
May 21, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add YGOPro-v1
parent
f6139c17
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
6345 additions
and
1358 deletions
+6345
-1358
docs/action.md
docs/action.md
+29
-0
docs/feature_engineering.md
docs/feature_engineering.md
+23
-31
scripts/battle.py
scripts/battle.py
+1
-1
scripts/eval.py
scripts/eval.py
+1
-2
xmake.lua
xmake.lua
+18
-1
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+491
-163
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+0
-494
ygoai/utils.py
ygoai/utils.py
+6
-1
ygoenv/ygoenv/core/BS_thread_pool.h
ygoenv/ygoenv/core/BS_thread_pool.h
+0
-0
ygoenv/ygoenv/entry.py
ygoenv/ygoenv/entry.py
+4
-1
ygoenv/ygoenv/ygopro/registration.py
ygoenv/ygoenv/ygopro/registration.py
+1
-1
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+931
-663
ygoenv/ygoenv/ygopro0/__init__.py
ygoenv/ygoenv/ygopro0/__init__.py
+22
-0
ygoenv/ygoenv/ygopro0/registration.py
ygoenv/ygoenv/ygopro0/registration.py
+10
-0
ygoenv/ygoenv/ygopro0/ygopro.cpp
ygoenv/ygoenv/ygopro0/ygopro.cpp
+11
-0
ygoenv/ygoenv/ygopro0/ygopro.h
ygoenv/ygoenv/ygopro0/ygopro.h
+4797
-0
No files found.
docs/action.md
0 → 100644
View file @
04e61b91
# Action
## Types
-
Set + card
-
Reposition + card
-
Special summon + card
-
Summon Face-up Attack + card
-
Summon Face-down Defense + card
-
Attack + card
-
DirectAttack + card
-
Activate + card + effect
-
Cancel
-
Switch + phase
-
SelectPosition + card + position
-
AnnounceNumber + card + effect + number
-
SelectPlace + card + place
-
AnnounceAttrib + card + effect + attrib
## Effect
### MSG_SELECT_BATTLECMD | MSG_SELECT_IDLECMD | MSG_SELECT_CHAIN | MSG_SELECT_EFFECTYN
-
desc == 0: default effect of card
-
desc < LIMIT: system string
-
desc > LIMIT: card + effect
### MSG_SELECT_OPTION | MSG_SELECT_YESNO
-
desc == 0: error
-
desc < LIMIT: system string
-
desc > LIMIT: card + effect
docs/feature_engineering.md
View file @
04e61b91
...
@@ -49,50 +49,42 @@ The card id is the index of the card code in `code_list.txt`.
...
@@ -49,50 +49,42 @@ The card id is the index of the card code in `code_list.txt`.
## Legal Actions
## Legal Actions
-
0,1: spec index, uint16 -> 2 uint8
-
0: spec index
-
2: msg, discrete, 0: N/A, 1+: same as msg2str (15)
-
1,2: code, uint16 -> 2 uint8
-
3: act, discrete (11)
-
3: msg, discrete, 0: N/A, 1+: same as msg2str (15)
-
4: act, discrete (11)
-
N/A
-
N/A
-
t: Set
-
Set
-
r: Reposition
-
Reposition
-
c: Special Summon
-
Special Summon
-
s: Summon Face-up Attack
-
Summon Face-up Attack
-
m: Summon Face-down Defense
-
Summon Face-down Defense
-
a: Attack
-
Attack
-
v: Activate
-
DirectAttack
-
v2: Activate the second effect
-
Activate
-
v3: Activate the third effect
-
Cancel
-
v4: Activate the fourth effect
-
5: finish, discrete (2)
-
4: yes/no, discrete (3)
-
N/A
-
N/A
-
Yes
-
Finish
-
No
-
6: effect, discrete, 0: N/A
-
5
: phase, discrete (4)
-
7
: phase, discrete (4)
-
N/A
-
N/A
-
Battle (b)
-
Battle (b)
-
Main Phase 2 (m)
-
Main Phase 2 (m)
-
End Phase (e)
-
End Phase (e)
-
6: cancel, discrete (2)
-
N/A
-
Cancel
-
7: finish, discrete (2)
-
N/A
-
Finish
-
8: position, discrete, 0: N/A, same as position2str
-
8: position, discrete, 0: N/A, same as position2str
-
9: option, discrete, 0: N/A
-
9: number, discrete, 0: N/A
-
10: number, discrete, 0: N/A
-
10: place, discrete
-
11: place, discrete
-
0: N/A
-
0: N/A
-
1-7: m
-
1-7: m
-
8-15: s
-
8-15: s
-
16-22: om
-
16-22: om
-
23-30: os
-
23-30: os
-
1
2
: attribute, discrete, 0: N/A, same as attribute2id
-
1
1
: attribute, discrete, 0: N/A, same as attribute2id
## History Actions
## History Actions
-
0,1: card id, uint16 -> 2 uint8
-
0,1: card id, uint16 -> 2 uint8
-
2-12 same as legal actions
-
2-11 same as legal actions
-
13: player, discrete, 0: me, 1: oppo
-
12: turn, discrete, trunc to 3
-
14: turn, discrete, trunc to 3
-
13: phase, discrete (10)
scripts/battle.py
View file @
04e61b91
...
@@ -18,7 +18,7 @@ import flax
...
@@ -18,7 +18,7 @@ import flax
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.jax.agent
2
import
RNNAgent
,
ModelArgs
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
@
dataclass
@
dataclass
...
...
scripts/eval.py
View file @
04e61b91
...
@@ -135,7 +135,7 @@ if __name__ == "__main__":
...
@@ -135,7 +135,7 @@ if __name__ == "__main__":
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
flax
import
flax
from
ygoai.rl.jax.agent
2
import
RNNAgent
from
ygoai.rl.jax.agent
import
RNNAgent
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
...
@@ -168,7 +168,6 @@ if __name__ == "__main__":
...
@@ -168,7 +168,6 @@ if __name__ == "__main__":
obs
,
infos
=
envs
.
reset
()
obs
,
infos
=
envs
.
reset
()
print
(
obs
)
next_to_play
=
infos
[
'to_play'
]
next_to_play
=
infos
[
'to_play'
]
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
...
...
xmake.lua
View file @
04e61b91
...
@@ -8,6 +8,24 @@ add_requires(
...
@@ -8,6 +8,24 @@ add_requires(
"sqlitecpp 3.2.1"
)
"sqlitecpp 3.2.1"
)
target
(
"ygopro0_ygoenv"
)
add_rules
(
"python.library"
)
add_files
(
"ygoenv/ygoenv/ygopro0/*.cpp"
)
add_packages
(
"pybind11"
,
"fmt"
,
"glog"
,
"concurrentqueue"
,
"sqlitecpp"
,
"unordered_dense"
,
"ygopro-core"
)
set_languages
(
"c++17"
)
if
is_mode
(
"release"
)
then
set_policy
(
"build.optimization.lto"
,
true
)
add_cxxflags
(
"-march=native"
)
end
add_includedirs
(
"ygoenv"
)
after_build
(
function
(
target
)
local
install_target
=
"$(projectdir)/ygoenv/ygoenv/ygopro0"
os
.
cp
(
target
:
targetfile
(),
install_target
)
print
(
"Copy target to "
..
install_target
)
end
)
target
(
"ygopro_ygoenv"
)
target
(
"ygopro_ygoenv"
)
add_rules
(
"python.library"
)
add_rules
(
"python.library"
)
add_files
(
"ygoenv/ygoenv/ygopro/*.cpp"
)
add_files
(
"ygoenv/ygoenv/ygopro/*.cpp"
)
...
@@ -25,7 +43,6 @@ target("ygopro_ygoenv")
...
@@ -25,7 +43,6 @@ target("ygopro_ygoenv")
print
(
"Copy target to "
..
install_target
)
print
(
"Copy target to "
..
install_target
)
end
)
end
)
target
(
"edopro_ygoenv"
)
target
(
"edopro_ygoenv"
)
add_rules
(
"python.library"
)
add_rules
(
"python.library"
)
add_files
(
"ygoenv/ygoenv/edopro/*.cpp"
)
add_files
(
"ygoenv/ygoenv/edopro/*.cpp"
)
...
...
ygoai/rl/jax/agent.py
View file @
04e61b91
from
typing
import
Tuple
,
Union
,
Optional
from
dataclasses
import
dataclass
from
typing
import
Tuple
,
Union
,
Optional
,
Sequence
,
Literal
from
functools
import
partial
from
functools
import
partial
import
numpy
as
np
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
flax.linen
as
nn
import
flax.linen
as
nn
from
ygoai.rl.jax.transformer
import
EncoderLayer
,
PositionalEncoding
,
LlamaEncoderLayer
from
ygoai.rl.jax.modules
import
MLP
,
make_bin_params
,
bytes_to_bin
,
decode_id
from
ygoai.rl.jax.modules
import
MLP
,
make_bin_params
,
bytes_to_bin
,
decode_id
from
ygoai.rl.jax.transformer
import
EncoderLayer
,
DecoderLayer
,
PositionalEncoding
default_embed_init
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
default_embed_init
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
...
@@ -14,11 +16,18 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
...
@@ -14,11 +16,18 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
default_fc_init2
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
def
get_encoder_layer_cls
(
noam
,
n_heads
,
dtype
,
param_dtype
):
if
noam
:
return
LlamaEncoderLayer
(
n_heads
,
dtype
=
dtype
,
param_dtype
=
param_dtype
,
rope
=
False
)
else
:
return
EncoderLayer
(
n_heads
,
dtype
=
dtype
,
param_dtype
=
param_dtype
)
class
ActionEncoder
(
nn
.
Module
):
class
ActionEncoder
(
nn
.
Module
):
channels
:
int
=
128
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
):
def
__call__
(
self
,
x
):
c
=
self
.
channels
c
=
self
.
channels
...
@@ -26,7 +35,6 @@ class ActionEncoder(nn.Module):
...
@@ -26,7 +35,6 @@ class ActionEncoder(nn.Module):
embed
=
partial
(
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
embedding_init
=
default_embed_init
)
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_act
=
embed
(
13
,
c
//
div
)(
x
[:,
:,
1
])
x_a_act
=
embed
(
13
,
c
//
div
)(
x
[:,
:,
1
])
x_a_yesno
=
embed
(
3
,
c
//
div
)(
x
[:,
:,
2
])
x_a_yesno
=
embed
(
3
,
c
//
div
)(
x
[:,
:,
2
])
...
@@ -38,18 +46,165 @@ class ActionEncoder(nn.Module):
...
@@ -38,18 +46,165 @@ class ActionEncoder(nn.Module):
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
8
])
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
8
])
x_a_place
=
embed
(
31
,
c
//
div
//
2
)(
x
[:,
:,
9
])
x_a_place
=
embed
(
31
,
c
//
div
//
2
)(
x
[:,
:,
9
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
10
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
10
])
return
jnp
.
concatenate
([
xs
=
[
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
]
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
],
axis
=-
1
)
return
xs
class
ActionEncoderV1
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
channels
div
=
8
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_act
=
embed
(
10
,
c
//
div
)(
x
[:,
:,
1
])
x_a_finish
=
embed
(
3
,
c
//
div
//
2
)(
x
[:,
:,
2
])
x_a_effect
=
embed
(
256
,
c
//
div
*
2
)(
x
[:,
:,
3
])
x_a_phase
=
embed
(
4
,
c
//
div
//
2
)(
x
[:,
:,
4
])
x_a_position
=
embed
(
9
,
c
//
div
)(
x
[:,
:,
5
])
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
6
])
x_a_place
=
embed
(
31
,
c
//
div
)(
x
[:,
:,
7
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
8
])
xs
=
[
x_a_msg
,
x_a_act
,
x_a_finish
,
x_a_effect
,
x_a_phase
,
x_a_position
,
x_a_number
,
x_a_place
,
x_a_attrib
]
return
xs
class
CardEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
version
:
int
=
0
@
nn
.
compact
def
__call__
(
self
,
x_id
,
x
):
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
num_fc
=
mlp
((
c
//
8
,),
last_lin
=
False
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
x1
=
x
[:,
:,
:
10
]
.
astype
(
jnp
.
int32
)
x2
=
x
[:,
:,
10
:]
.
astype
(
jnp
.
float32
)
x_loc
=
x1
[:,
:,
0
]
x_seq
=
x1
[:,
:,
1
]
if
self
.
version
==
0
:
x_id
=
mlp
(
(
c
,
c
//
4
),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
layer_norm
()(
x_id
)
f_loc
=
layer_norm
()(
embed
(
9
,
c
)(
x_loc
))
f_seq
=
layer_norm
()(
embed
(
76
,
c
)(
x_seq
))
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
x_owner
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
2
])
x_position
=
embed
(
9
,
c
//
16
)(
x1
[:,
:,
3
])
x_overley
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
4
])
x_attribute
=
embed
(
8
,
c
//
16
)(
x1
[:,
:,
5
])
x_race
=
embed
(
27
,
c
//
16
)(
x1
[:,
:,
6
])
x_level
=
embed
(
14
,
c
//
16
)(
x1
[:,
:,
7
])
x_counter
=
embed
(
16
,
c
//
16
)(
x1
[:,
:,
8
])
x_negated
=
embed
(
3
,
c
//
16
)(
x1
[:,
:,
9
])
x_atk
=
num_transform
(
x2
[:,
:,
0
:
2
])
x_atk
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_atk
)
x_def
=
num_transform
(
x2
[:,
:,
2
:
4
])
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x2
[:,
:,
4
:])
if
self
.
version
==
0
:
x_f
=
jnp
.
concatenate
([
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_f
=
layer_norm
()(
x_f
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
else
:
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
f_seq
=
embed
(
76
,
c
//
16
*
2
)(
x_seq
)
x_cards
=
jnp
.
concatenate
([
f_loc
,
f_seq
,
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_cards
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards
)
x_cards
=
x_cards
*
x_id
f_cards
=
layer_norm
()(
x_cards
)
return
f_cards
,
c_mask
class
GlobalEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
batch_size
=
x
.
shape
[
0
]
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
count_embed
=
embed
(
100
,
c
//
16
)
hand_count_embed
=
embed
(
100
,
c
//
16
)
num_fc
=
mlp
((
c
//
8
,),
last_lin
=
False
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
x1
=
x
[:,
:
4
]
.
astype
(
jnp
.
float32
)
x2
=
x
[:,
4
:
8
]
.
astype
(
jnp
.
int32
)
x3
=
x
[:,
8
:
22
]
.
astype
(
jnp
.
int32
)
x_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x1
[:,
0
:
2
]))
x_oppo_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x1
[:,
2
:
4
]))
x_turn
=
embed
(
20
,
c
//
8
)(
x2
[:,
0
])
x_phase
=
embed
(
11
,
c
//
8
)(
x2
[:,
1
])
x_if_first
=
embed
(
2
,
c
//
8
)(
x2
[:,
2
])
x_is_my_turn
=
embed
(
2
,
c
//
8
)(
x2
[:,
3
])
x_cs
=
count_embed
(
x3
)
.
reshape
((
batch_size
,
-
1
))
x_my_hand_c
=
hand_count_embed
(
x3
[:,
1
])
x_op_hand_c
=
hand_count_embed
(
x3
[:,
8
])
x
=
jnp
.
concatenate
([
x_lp
,
x_oppo_lp
,
x_turn
,
x_phase
,
x_if_first
,
x_is_my_turn
,
x_cs
,
x_my_hand_c
,
x_op_hand_c
],
axis
=-
1
)
x
=
layer_norm
()(
x
)
return
x
class
Encoder
(
nn
.
Module
):
class
Encoder
(
nn
.
Module
):
channels
:
int
=
128
channels
:
int
=
128
num_card_layers
:
int
=
2
num_layers
:
int
=
2
num_action_layers
:
int
=
2
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
freeze_id
:
bool
=
False
use_history
:
bool
=
True
card_mask
:
bool
=
False
noam
:
bool
=
False
version
:
int
=
0
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
):
def
__call__
(
self
,
x
):
...
@@ -62,154 +217,182 @@ class Encoder(nn.Module):
...
@@ -62,154 +217,182 @@ class Encoder(nn.Module):
n_embed
,
embed_dim
=
self
.
embedding_shape
n_embed
,
embed_dim
=
self
.
embedding_shape
n_embed
=
1
+
n_embed
# 1 (index 0) for unknown
n_embed
=
1
+
n_embed
# 1 (index 0) for unknown
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
False
,
use_bias
=
Fals
e
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
Tru
e
)
embed
=
partial
(
embed
=
partial
(
nn
.
Embed
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
nn
.
Embed
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
count_embed
=
embed
(
100
,
c
//
16
)
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
hand_count_embed
=
embed
(
100
,
c
//
16
)
action_encoder
=
ActionEncoderCls
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
num_fc
=
MLP
((
c
//
8
,),
last_lin
=
False
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
action_encoder
=
ActionEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards
=
x
[
'cards_'
]
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
batch_size
=
x_cards
.
shape
[
0
]
batch_size
=
x_cards
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
valid
=
x_global
[:,
-
1
]
==
0
x_cards_1
=
x_cards
[:,
:,
:
12
]
.
astype
(
jnp
.
int32
)
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_cards_2
=
x_cards
[:,
:,
12
:]
.
astype
(
jnp
.
float32
)
x_id
=
decode_id
(
x_cards_1
[:,
:,
:
2
])
x_id
=
id_embed
(
x_id
)
x_id
=
id_embed
(
x_id
)
x_id
=
MLP
(
if
self
.
freeze_id
:
(
c
,
c
//
4
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
layer_norm
()(
x_id
)
x_loc
=
x_cards_1
[:,
:,
2
]
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
f_loc
=
layer_norm
()(
embed
(
9
,
c
)(
x_loc
))
x_seq
=
x_cards_1
[:,
:,
3
]
# Cards
f_seq
=
layer_norm
()(
embed
(
76
,
c
)(
x_seq
))
f_cards
,
c_mask
=
CardEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
)(
x_id
,
x_cards
[:,
:,
2
:])
x_owner
=
embed
(
2
,
c
//
16
)(
x_cards_1
[:,
:,
4
])
g_card_embed
=
self
.
param
(
x_position
=
embed
(
9
,
c
//
16
)(
x_cards_1
[:,
:,
5
])
'g_card_embed'
,
x_overley
=
embed
(
2
,
c
//
16
)(
x_cards_1
[:,
:,
6
])
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
x_attribute
=
embed
(
8
,
c
//
16
)(
x_cards_1
[:,
:,
7
])
(
1
,
c
),
self
.
param_dtype
)
x_race
=
embed
(
27
,
c
//
16
)(
x_cards_1
[:,
:,
8
])
f_g_card
=
jnp
.
tile
(
g_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
x_level
=
embed
(
14
,
c
//
16
)(
x_cards_1
[:,
:,
9
])
f_cards
=
jnp
.
concatenate
([
f_g_card
,
f_cards
],
axis
=
1
)
x_counter
=
embed
(
16
,
c
//
16
)(
x_cards_1
[:,
:,
10
])
if
self
.
card_mask
:
x_negated
=
embed
(
3
,
c
//
16
)(
x_cards_1
[:,
:,
11
])
c_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
c_mask
.
dtype
),
c_mask
],
axis
=
1
)
else
:
x_atk
=
num_transform
(
x_cards_2
[:,
:,
0
:
2
])
c_mask
=
None
x_atk
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_atk
)
x_def
=
num_transform
(
x_cards_2
[:,
:,
2
:
4
])
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x_cards_2
[:,
:,
4
:])
x_feat
=
jnp
.
concatenate
([
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_feat
=
layer_norm
()(
x_feat
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_feat
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
num_heads
=
max
(
2
,
c
//
128
)
num_heads
=
max
(
2
,
c
//
128
)
for
_
in
range
(
self
.
num_card_layers
):
for
_
in
range
(
self
.
num_layers
):
f_cards
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_cards
)
f_cards
=
get_encoder_layer_cls
(
self
.
noam
,
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_cards
,
src_key_padding_mask
=
c_mask
)
f_cards
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_cards
)
f_g_card
=
f_cards
[:,
0
]
# Global
x_global
=
GlobalEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
x_global
=
x_global
.
astype
(
self
.
dtype
)
f_global
=
x_global
+
MLP
((
c
*
2
,
c
*
2
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
f_global
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_global
)
f_global
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_global
)
# History actions
x_h_actions
=
x_h_actions
.
astype
(
jnp
.
int32
)
if
self
.
version
==
0
:
h_mask
=
x_h_actions
[:,
:,
2
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
:
2
])
x_h_id
=
id_embed
(
x_h_id
)
if
self
.
freeze_id
:
x_h_id
=
jax
.
lax
.
stop_gradient
(
x_h_id
)
x_h_id
=
MLP
(
(
c
,
c
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
x_h_id
)
x_h_a_feats1
=
action_encoder
(
x_h_actions
[:,
:,
2
:
13
])
x_h_a_player
=
embed
(
2
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
14
])
x_h_a_feats
=
jnp
.
concatenate
([
*
x_h_a_feats1
,
x_h_a_player
,
x_h_a_turn
],
axis
=-
1
)
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_h_a_feats
))
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
for
_
in
range
(
self
.
num_layers
):
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_h_actions
,
src_key_padding_mask
=
h_mask
)
f_g_h_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_h_actions
[:,
0
])
else
:
h_mask
=
x_h_actions
[:,
:,
3
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
1
:
3
])
x_h_id
=
id_embed
(
x_h_id
)
if
self
.
freeze_id
:
x_h_id
=
jax
.
lax
.
stop_gradient
(
x_h_id
)
x_h_id
=
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_h_id
)
x_h_a_feats
=
action_encoder
(
x_h_actions
[:,
:,
3
:
12
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
12
])
x_h_a_phase
=
embed
(
12
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_feats
.
extend
([
x_h_id
,
x_h_a_turn
,
x_h_a_phase
])
x_h_a_feats
=
jnp
.
concatenate
(
x_h_a_feats
,
axis
=-
1
)
x_h_a_feats
=
layer_norm
()(
x_h_a_feats
)
x_h_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_h_a_feats
)
if
self
.
noam
:
f_h_actions
=
LlamaEncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
rope
=
True
,
n_positions
=
64
)(
x_h_a_feats
,
src_key_padding_mask
=
h_mask
)
else
:
x_h_a_feats
=
PositionalEncoding
()(
x_h_a_feats
)
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
x_h_a_feats
,
src_key_padding_mask
=
h_mask
)
f_g_h_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_h_actions
[:,
0
])
# Actions
x_actions
=
x_actions
.
astype
(
jnp
.
int32
)
na_card_embed
=
self
.
param
(
na_card_embed
=
self
.
param
(
'na_card_embed'
,
'na_card_embed'
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
(
1
,
c
),
self
.
param_dtype
)
(
1
,
c
),
self
.
param_dtype
)
f_na_card
=
jnp
.
tile
(
na_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_na_card
=
jnp
.
tile
(
na_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
],
axis
=
1
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
[:,
1
:]],
axis
=
1
)
c_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
c_mask
.
dtype
),
c_mask
],
axis
=
1
)
f_cards
=
layer_norm
()(
f_cards
)
if
self
.
version
==
0
:
spec_index
=
decode_id
(
x_actions
[
...
,
:
2
])
x_global_1
=
x_global
[:,
:
4
]
.
astype
(
jnp
.
float32
)
B
=
jnp
.
arange
(
batch_size
)
x_g_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x_global_1
[:,
0
:
2
]))
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
x_g_oppo_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x_global_1
[:,
2
:
4
]))
f_a_cards
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_a_cards
)
x_global_2
=
x_global
[:,
4
:
8
]
.
astype
(
jnp
.
int32
)
x_a_feats
=
jnp
.
concatenate
(
action_encoder
(
x_actions
[
...
,
2
:]),
axis
=-
1
)
x_g_turn
=
embed
(
20
,
c
//
8
)(
x_global_2
[:,
0
])
x_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_a_feats
)
x_g_phase
=
embed
(
11
,
c
//
8
)(
x_global_2
[:,
1
])
f_actions
=
jnp
.
concatenate
([
f_a_cards
,
x_a_feats
],
axis
=-
1
)
x_g_if_first
=
embed
(
2
,
c
//
8
)(
x_global_2
[:,
2
])
f_actions
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
nn
.
leaky_relu
(
f_actions
,
negative_slope
=
0.1
))
x_g_is_my_turn
=
embed
(
2
,
c
//
8
)(
x_global_2
[:,
3
])
f_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_actions
)
x_global_3
=
x_global
[:,
8
:
22
]
.
astype
(
jnp
.
int32
)
a_mask
=
x_actions
[:,
:,
2
]
==
0
x_g_cs
=
count_embed
(
x_global_3
)
.
reshape
((
batch_size
,
-
1
))
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
x_g_my_hand_c
=
hand_count_embed
(
x_global_3
[:,
1
])
x_g_op_hand_c
=
hand_count_embed
(
x_global_3
[:,
8
])
a_mask_
=
(
1
-
a_mask
.
astype
(
f_actions
.
dtype
))
f_g_actions
=
(
f_actions
*
a_mask_
[:,
:,
None
])
.
sum
(
axis
=
1
)
x_global
=
jnp
.
concatenate
([
f_g_actions
=
f_g_actions
/
a_mask_
.
sum
(
axis
=
1
,
keepdims
=
True
)
x_g_lp
,
x_g_oppo_lp
,
x_g_turn
,
x_g_phase
,
x_g_if_first
,
x_g_is_my_turn
,
if
not
self
.
use_history
:
x_g_cs
,
x_g_my_hand_c
,
x_g_op_hand_c
],
axis
=-
1
)
f_g_h_actions
=
jnp
.
zeros_like
(
f_g_h_actions
)
x_global
=
layer_norm
()(
x_global
)
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
f_global
=
x_global
+
MLP
((
c
*
2
,
c
*
2
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
else
:
f_global
=
fc_layer
(
c
)(
f_global
)
spec_index
=
x_actions
[
...
,
0
]
f_global
=
layer_norm
()(
f_global
)
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_cards
=
f_cards
+
jnp
.
expand_dims
(
f_global
,
1
)
x_a_id
=
decode_id
(
x_actions
[
...
,
1
:
3
])
x_actions
=
x_actions
.
astype
(
jnp
.
int32
)
x_a_id
=
id_embed
(
x_a_id
)
if
self
.
freeze_id
:
spec_index
=
decode_id
(
x_actions
[
...
,
:
2
])
x_a_id
=
jax
.
lax
.
stop_gradient
(
x_a_id
)
B
=
jnp
.
arange
(
batch_size
)
x_a_id
=
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_a_id
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
f_a_cards
+
fc_layer
(
c
)(
layer_norm
()(
f_a_cards
))
x_a_feats
=
action_encoder
(
x_actions
[
...
,
3
:])
x_a_feats
.
append
(
x_a_id
)
x_a_feats
=
action_encoder
(
x_actions
[
...
,
2
:])
x_a_feats
=
jnp
.
concatenate
(
x_a_feats
,
axis
=-
1
)
f_actions
=
f_a_cards
+
layer_norm
()(
x_a_feats
)
x_a_feats
=
layer_norm
()(
x_a_feats
)
x_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_a_feats
)
a_mask
=
x_actions
[:,
:,
2
]
==
0
f_a_cards
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_a_cards
)
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
f_actions
=
jax
.
nn
.
silu
(
f_a_cards
)
*
x_a_feats
for
_
in
range
(
self
.
num_action_layers
):
f_actions
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_actions
)
f_actions
=
DecoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_actions
=
x_a_feats
+
f_actions
f_actions
,
f_cards
,
tgt_key_padding_mask
=
a_mask
,
a_mask
=
x_actions
[:,
:,
3
]
==
0
memory_key_padding_mask
=
c_mask
)
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_actions
=
x
[
'h_actions_'
]
.
astype
(
jnp
.
int32
)
f_actions_g
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_actions
)
h_mask
=
x_h_actions
[:,
:,
2
]
==
0
# msg == 0
a_mask_
=
(
1
-
a_mask
.
astype
(
f_actions
.
dtype
))
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
f_g_actions
=
(
f_actions_g
*
a_mask_
[:,
:,
None
])
.
sum
(
axis
=
1
)
f_g_actions
=
f_g_actions
/
a_mask_
.
sum
(
axis
=
1
,
keepdims
=
True
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
:
2
])
if
self
.
use_history
:
x_h_id
=
MLP
(
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
(
c
,
c
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
else
:
kernel_init
=
default_fc_init2
)(
id_embed
(
x_h_id
))
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_actions
],
axis
=-
1
)
f_state
=
MLP
((
c
*
2
,
c
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
x_h_a_feats
=
action_encoder
(
x_h_actions
[:,
:,
2
:])
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
x_h_a_feats
)
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
for
_
in
range
(
self
.
num_action_layers
):
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_h_actions
,
src_key_padding_mask
=
h_mask
)
for
_
in
range
(
self
.
num_action_layers
):
f_actions
=
DecoderLayer
(
num_heads
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
f_h_actions
,
tgt_key_padding_mask
=
a_mask
,
memory_key_padding_mask
=
h_mask
)
f_actions
=
layer_norm
()(
f_actions
)
f_s_cards_global
=
f_cards
.
mean
(
axis
=
1
)
c_mask
=
1
-
a_mask
[:,
:,
None
]
.
astype
(
f_actions
.
dtype
)
f_s_actions_ha
=
(
f_actions
*
c_mask
)
.
sum
(
axis
=
1
)
/
c_mask
.
sum
(
axis
=
1
)
f_state
=
jnp
.
concatenate
([
f_s_cards_global
,
f_s_actions_ha
],
axis
=-
1
)
return
f_actions
,
f_state
,
a_mask
,
valid
return
f_actions
,
f_state
,
a_mask
,
valid
...
@@ -219,54 +402,199 @@ class Actor(nn.Module):
...
@@ -219,54 +402,199 @@ class Actor(nn.Module):
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
f_actions
,
mask
):
def
__call__
(
self
,
f_state
,
f_actions
,
mask
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
f_actions
=
f_actions
.
astype
(
self
.
dtype
)
c
=
self
.
channels
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
last_kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))
mlp
=
partial
(
MLP
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
last_kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))
f_state
=
mlp
((
c
,),
use_bias
=
True
)(
f_state
)
logits
=
jnp
.
einsum
(
'bc,bnc->bn'
,
f_state
,
f_actions
)
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
return
logits
class
FiLMActor
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
noam
:
bool
=
False
@
nn
.
compact
def
__call__
(
self
,
f_state
,
f_actions
,
mask
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
f_actions
=
f_actions
.
astype
(
self
.
dtype
)
c
=
self
.
channels
t
=
nn
.
Dense
(
c
*
4
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
a_s
,
a_b
,
o_s
,
o_b
=
jnp
.
split
(
t
[:,
None
,
:],
4
,
axis
=-
1
)
num_heads
=
max
(
2
,
c
//
128
)
num_heads
=
max
(
2
,
c
//
128
)
f_actions
=
EncoderLayer
(
f_actions
=
get_encoder_layer_cls
(
num_heads
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
src_key_padding_mask
=
mask
)
self
.
noam
,
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
logits
=
mlp
((
c
//
4
,
1
),
use_bias
=
True
)(
f_actions
)
f_actions
,
mask
,
a_s
,
a_b
,
o_s
,
o_b
)
logits
=
logits
[
...
,
0
]
logits
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))(
f_actions
)[:,
:,
0
]
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
return
logits
return
logits
class
Critic
(
nn
.
Module
):
class
Critic
(
nn
.
Module
):
channels
:
int
=
128
channels
:
Sequence
[
int
]
=
(
128
,
128
,
128
)
dtype
:
Optional
[
jnp
.
dtype
]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
f_state
):
def
__call__
(
self
,
f_state
):
c
=
self
.
channels
f_state
=
f_state
.
astype
(
self
.
dtype
)
mlp
=
partial
(
MLP
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
last_kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
x
=
MLP
((
c
//
2
,
1
),
use_bias
=
True
)(
f_state
)
x
=
mlp
(
self
.
channels
,
last_lin
=
False
)(
f_state
)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
return
x
return
x
class
PPOAgent
(
nn
.
Module
):
def
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
main
):
channels
:
int
=
128
if
main
is
not
None
:
num_card_layers
:
int
=
2
rstate1
,
rstate2
=
rstate
num_action_layers
:
int
=
2
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
f_state
=
rnn_layer
(
rstate
,
f_state
)
if
main
is
not
None
:
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate
=
rstate1
,
rstate2
if
done
is
not
None
:
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
return
rstate
,
f_state
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
rstate
,
y
=
cell
(
rstate
,
x
)
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
rstate
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
switch
[:,
None
],
x
,
y
),
init_rstate2
,
rstate
)
return
(
rstate
,
init_rstate2
),
y
else
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
return
rnn_step_by_main
(
cell
,
carry
,
x
,
done
,
main
)
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
rstate
,
f_state
=
scan
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
return
rstate
,
f_state
@
dataclass
class
ModelArgs
:
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
use_history
:
bool
=
True
"""whether to use history actions as input for agent"""
card_mask
:
bool
=
False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
False
"""whether to use FiLM for the actor"""
noam
:
bool
=
False
"""whether to use Noam architecture for the transformer layer"""
version
:
int
=
0
"""the version of the environment and the agent"""
class
RNNAgent
(
nn
.
Module
):
num_layers
:
int
=
2
num_channels
:
int
=
128
rnn_channels
:
int
=
512
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
switch
:
bool
=
True
freeze_id
:
bool
=
False
use_history
:
bool
=
True
card_mask
:
bool
=
False
rnn_type
:
str
=
'lstm'
film
:
bool
=
False
noam
:
bool
=
False
version
:
int
=
0
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
):
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
c
=
self
.
num_channels
encoder
=
Encoder
(
encoder
=
Encoder
(
channels
=
self
.
channels
,
channels
=
c
,
num_card_layers
=
self
.
num_card_layers
,
num_layers
=
self
.
num_layers
,
num_action_layers
=
self
.
num_action_layers
,
embedding_shape
=
self
.
embedding_shape
,
embedding_shape
=
self
.
embedding_shape
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
param_dtype
=
self
.
param_dtype
,
freeze_id
=
self
.
freeze_id
,
use_history
=
self
.
use_history
,
card_mask
=
self
.
card_mask
,
noam
=
self
.
noam
,
version
=
self
.
version
,
)
)
actor
=
Actor
(
channels
=
self
.
channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
channels
=
self
.
channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
logits
=
actor
(
f_actions
,
mask
)
value
=
critic
(
f_state
)
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
return
logits
,
value
,
valid
rnn_layer
=
nn
.
OptimizedLSTMCell
(
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
elif
self
.
rnn_type
==
'gru'
:
rnn_layer
=
nn
.
GRUCell
(
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
elif
self
.
rnn_type
is
None
:
rnn_layer
=
None
if
rnn_layer
is
None
:
f_state_r
=
f_state
elif
self
.
rnn_type
==
'none'
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
else
:
batch_size
=
jax
.
tree
.
leaves
(
rstate
)[
0
]
.
shape
[
0
]
num_steps
=
f_state
.
shape
[
0
]
//
batch_size
multi_step
=
num_steps
>
1
if
done
is
not
None
:
assert
switch_or_main
is
not
None
else
:
assert
not
multi_step
if
multi_step
:
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
rstate
,
f_state_r
=
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
if
self
.
film
:
actor
=
FiLMActor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
noam
=
self
.
noam
)
else
:
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
return
rstate
,
logits
,
value
,
valid
def
init_rnn_state
(
self
,
batch_size
):
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
return
(
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
)
elif
self
.
rnn_type
==
'gru'
:
return
np
.
zeros
((
batch_size
,
self
.
rnn_channels
))
else
:
return
None
\ No newline at end of file
ygoai/rl/jax/agent2.py
deleted
100644 → 0
View file @
f6139c17
from
dataclasses
import
dataclass
from
typing
import
Tuple
,
Union
,
Optional
,
Sequence
,
Literal
from
functools
import
partial
import
numpy
as
np
import
jax
import
jax.numpy
as
jnp
import
flax.linen
as
nn
from
ygoai.rl.jax.transformer
import
EncoderLayer
,
PositionalEncoding
,
LlamaEncoderLayer
from
ygoai.rl.jax.modules
import
MLP
,
make_bin_params
,
bytes_to_bin
,
decode_id
default_embed_init
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
default_fc_init1
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
default_fc_init2
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
def
get_encoder_layer_cls
(
noam
,
n_heads
,
dtype
,
param_dtype
):
if
noam
:
return
LlamaEncoderLayer
(
n_heads
,
dtype
=
dtype
,
param_dtype
=
param_dtype
,
rope
=
False
)
else
:
return
EncoderLayer
(
n_heads
,
dtype
=
dtype
,
param_dtype
=
param_dtype
)
class
ActionEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
channels
div
=
8
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_act
=
embed
(
13
,
c
//
div
)(
x
[:,
:,
1
])
x_a_yesno
=
embed
(
3
,
c
//
div
)(
x
[:,
:,
2
])
x_a_phase
=
embed
(
4
,
c
//
div
)(
x
[:,
:,
3
])
x_a_cancel
=
embed
(
3
,
c
//
div
)(
x
[:,
:,
4
])
x_a_finish
=
embed
(
3
,
c
//
div
//
2
)(
x
[:,
:,
5
])
x_a_position
=
embed
(
9
,
c
//
div
//
2
)(
x
[:,
:,
6
])
x_a_option
=
embed
(
6
,
c
//
div
//
2
)(
x
[:,
:,
7
])
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
8
])
x_a_place
=
embed
(
31
,
c
//
div
//
2
)(
x
[:,
:,
9
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
10
])
xs
=
[
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
]
return
xs
class
CardEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x_id
,
x
):
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
num_fc
=
mlp
((
c
//
8
,),
last_lin
=
False
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
x1
=
x
[:,
:,
:
10
]
.
astype
(
jnp
.
int32
)
x2
=
x
[:,
:,
10
:]
.
astype
(
jnp
.
float32
)
x_id
=
mlp
(
(
c
,
c
//
4
),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
layer_norm
()(
x_id
)
x_loc
=
x1
[:,
:,
0
]
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
f_loc
=
layer_norm
()(
embed
(
9
,
c
)(
x_loc
))
x_seq
=
x1
[:,
:,
1
]
f_seq
=
layer_norm
()(
embed
(
76
,
c
)(
x_seq
))
x_owner
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
2
])
x_position
=
embed
(
9
,
c
//
16
)(
x1
[:,
:,
3
])
x_overley
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
4
])
x_attribute
=
embed
(
8
,
c
//
16
)(
x1
[:,
:,
5
])
x_race
=
embed
(
27
,
c
//
16
)(
x1
[:,
:,
6
])
x_level
=
embed
(
14
,
c
//
16
)(
x1
[:,
:,
7
])
x_counter
=
embed
(
16
,
c
//
16
)(
x1
[:,
:,
8
])
x_negated
=
embed
(
3
,
c
//
16
)(
x1
[:,
:,
9
])
x_atk
=
num_transform
(
x2
[:,
:,
0
:
2
])
x_atk
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_atk
)
x_def
=
num_transform
(
x2
[:,
:,
2
:
4
])
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x2
[:,
:,
4
:])
x_f
=
jnp
.
concatenate
([
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_f
=
layer_norm
()(
x_f
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
return
f_cards
,
c_mask
class
GlobalEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
batch_size
=
x
.
shape
[
0
]
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
count_embed
=
embed
(
100
,
c
//
16
)
hand_count_embed
=
embed
(
100
,
c
//
16
)
num_fc
=
mlp
((
c
//
8
,),
last_lin
=
False
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
x1
=
x
[:,
:
4
]
.
astype
(
jnp
.
float32
)
x2
=
x
[:,
4
:
8
]
.
astype
(
jnp
.
int32
)
x3
=
x
[:,
8
:
22
]
.
astype
(
jnp
.
int32
)
x_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x1
[:,
0
:
2
]))
x_oppo_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x1
[:,
2
:
4
]))
x_turn
=
embed
(
20
,
c
//
8
)(
x2
[:,
0
])
x_phase
=
embed
(
11
,
c
//
8
)(
x2
[:,
1
])
x_if_first
=
embed
(
2
,
c
//
8
)(
x2
[:,
2
])
x_is_my_turn
=
embed
(
2
,
c
//
8
)(
x2
[:,
3
])
x_cs
=
count_embed
(
x3
)
.
reshape
((
batch_size
,
-
1
))
x_my_hand_c
=
hand_count_embed
(
x3
[:,
1
])
x_op_hand_c
=
hand_count_embed
(
x3
[:,
8
])
x
=
jnp
.
concatenate
([
x_lp
,
x_oppo_lp
,
x_turn
,
x_phase
,
x_if_first
,
x_is_my_turn
,
x_cs
,
x_my_hand_c
,
x_op_hand_c
],
axis
=-
1
)
x
=
layer_norm
()(
x
)
return
x
class
Encoder
(
nn
.
Module
):
channels
:
int
=
128
num_layers
:
int
=
2
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
freeze_id
:
bool
=
False
use_history
:
bool
=
True
card_mask
:
bool
=
False
noam
:
bool
=
False
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
channels
if
self
.
embedding_shape
is
None
:
n_embed
,
embed_dim
=
999
,
1024
elif
isinstance
(
self
.
embedding_shape
,
int
):
n_embed
,
embed_dim
=
self
.
embedding_shape
,
1024
else
:
n_embed
,
embed_dim
=
self
.
embedding_shape
n_embed
=
1
+
n_embed
# 1 (index 0) for unknown
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
action_encoder
=
ActionEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
batch_size
=
x_cards
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
id_embed
(
x_id
)
if
self
.
freeze_id
:
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
# Cards
f_cards
,
c_mask
=
CardEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
x_id
,
x_cards
[:,
:,
2
:])
g_card_embed
=
self
.
param
(
'g_card_embed'
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
(
1
,
c
),
self
.
param_dtype
)
f_g_card
=
jnp
.
tile
(
g_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_cards
=
jnp
.
concatenate
([
f_g_card
,
f_cards
],
axis
=
1
)
if
self
.
card_mask
:
c_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
c_mask
.
dtype
),
c_mask
],
axis
=
1
)
else
:
c_mask
=
None
num_heads
=
max
(
2
,
c
//
128
)
for
_
in
range
(
self
.
num_layers
):
f_cards
=
get_encoder_layer_cls
(
self
.
noam
,
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_cards
,
src_key_padding_mask
=
c_mask
)
f_cards
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_cards
)
f_g_card
=
f_cards
[:,
0
]
# Global
x_global
=
GlobalEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
x_global
=
x_global
.
astype
(
self
.
dtype
)
f_global
=
x_global
+
MLP
((
c
*
2
,
c
*
2
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
f_global
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_global
)
f_global
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_global
)
# History actions
x_h_actions
=
x_h_actions
.
astype
(
jnp
.
int32
)
h_mask
=
x_h_actions
[:,
:,
2
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
:
2
])
x_h_id
=
id_embed
(
x_h_id
)
if
self
.
freeze_id
:
x_h_id
=
jax
.
lax
.
stop_gradient
(
x_h_id
)
x_h_id
=
MLP
(
(
c
,
c
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
x_h_id
)
x_h_a_feats1
=
action_encoder
(
x_h_actions
[:,
:,
2
:
13
])
x_h_a_player
=
embed
(
2
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
14
])
x_h_a_feats
=
jnp
.
concatenate
([
*
x_h_a_feats1
,
x_h_a_player
,
x_h_a_turn
],
axis
=-
1
)
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_h_a_feats
))
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
for
_
in
range
(
self
.
num_layers
):
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_h_actions
,
src_key_padding_mask
=
h_mask
)
f_g_h_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_h_actions
[:,
0
])
# Actions
x_actions
=
x_actions
.
astype
(
jnp
.
int32
)
na_card_embed
=
self
.
param
(
'na_card_embed'
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
(
1
,
c
),
self
.
param_dtype
)
f_na_card
=
jnp
.
tile
(
na_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
[:,
1
:]],
axis
=
1
)
spec_index
=
decode_id
(
x_actions
[
...
,
:
2
])
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_a_cards
)
x_a_feats
=
jnp
.
concatenate
(
action_encoder
(
x_actions
[
...
,
2
:]),
axis
=-
1
)
x_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_a_feats
)
f_actions
=
jnp
.
concatenate
([
f_a_cards
,
x_a_feats
],
axis
=-
1
)
f_actions
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
nn
.
leaky_relu
(
f_actions
,
negative_slope
=
0.1
))
f_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_actions
)
a_mask
=
x_actions
[:,
:,
2
]
==
0
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
a_mask_
=
(
1
-
a_mask
.
astype
(
f_actions
.
dtype
))
f_g_actions
=
(
f_actions
*
a_mask_
[:,
:,
None
])
.
sum
(
axis
=
1
)
f_g_actions
=
f_g_actions
/
a_mask_
.
sum
(
axis
=
1
,
keepdims
=
True
)
# State
if
not
self
.
use_history
:
f_g_h_actions
=
jnp
.
zeros_like
(
f_g_h_actions
)
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
f_state
=
MLP
((
c
*
2
,
c
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
return
f_actions
,
f_state
,
a_mask
,
valid
class
Actor
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
f_state
,
f_actions
,
mask
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
f_actions
=
f_actions
.
astype
(
self
.
dtype
)
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
last_kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))
f_state
=
mlp
((
c
,),
use_bias
=
True
)(
f_state
)
logits
=
jnp
.
einsum
(
'bc,bnc->bn'
,
f_state
,
f_actions
)
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
return
logits
class
FiLMActor
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
noam
:
bool
=
False
@
nn
.
compact
def
__call__
(
self
,
f_state
,
f_actions
,
mask
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
f_actions
=
f_actions
.
astype
(
self
.
dtype
)
c
=
self
.
channels
t
=
nn
.
Dense
(
c
*
4
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
a_s
,
a_b
,
o_s
,
o_b
=
jnp
.
split
(
t
[:,
None
,
:],
4
,
axis
=-
1
)
num_heads
=
max
(
2
,
c
//
128
)
f_actions
=
get_encoder_layer_cls
(
self
.
noam
,
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
mask
,
a_s
,
a_b
,
o_s
,
o_b
)
logits
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))(
f_actions
)[:,
:,
0
]
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
return
logits
class
Critic
(
nn
.
Module
):
channels
:
Sequence
[
int
]
=
(
128
,
128
,
128
)
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
f_state
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
x
=
mlp
(
self
.
channels
,
last_lin
=
False
)(
f_state
)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
return
x
def
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
main
):
if
main
is
not
None
:
rstate1
,
rstate2
=
rstate
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
f_state
=
rnn_layer
(
rstate
,
f_state
)
if
main
is
not
None
:
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate
=
rstate1
,
rstate2
if
done
is
not
None
:
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
return
rstate
,
f_state
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
rstate
,
y
=
cell
(
rstate
,
x
)
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
rstate
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
switch
[:,
None
],
x
,
y
),
init_rstate2
,
rstate
)
return
(
rstate
,
init_rstate2
),
y
else
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
return
rnn_step_by_main
(
cell
,
carry
,
x
,
done
,
main
)
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
rstate
,
f_state
=
scan
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
return
rstate
,
f_state
@
dataclass
class
ModelArgs
:
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
use_history
:
bool
=
True
"""whether to use history actions as input for agent"""
card_mask
:
bool
=
False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type
:
Optional
[
Literal
[
'lstm'
,
'gru'
,
'none'
]]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
False
"""whether to use FiLM for the actor"""
noam
:
bool
=
False
"""whether to use Noam architecture for the transformer layer"""
class
RNNAgent
(
nn
.
Module
):
num_layers
:
int
=
2
num_channels
:
int
=
128
rnn_channels
:
int
=
512
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
switch
:
bool
=
True
freeze_id
:
bool
=
False
use_history
:
bool
=
True
card_mask
:
bool
=
False
rnn_type
:
str
=
'lstm'
film
:
bool
=
False
noam
:
bool
=
False
@
nn
.
compact
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
c
=
self
.
num_channels
encoder
=
Encoder
(
channels
=
c
,
num_layers
=
self
.
num_layers
,
embedding_shape
=
self
.
embedding_shape
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
freeze_id
=
self
.
freeze_id
,
use_history
=
self
.
use_history
,
card_mask
=
self
.
card_mask
,
noam
=
self
.
noam
,
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
rnn_layer
=
nn
.
OptimizedLSTMCell
(
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
elif
self
.
rnn_type
==
'gru'
:
rnn_layer
=
nn
.
GRUCell
(
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
elif
self
.
rnn_type
is
None
:
rnn_layer
=
None
if
rnn_layer
is
None
:
f_state_r
=
f_state
elif
self
.
rnn_type
==
'none'
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
else
:
batch_size
=
jax
.
tree
.
leaves
(
rstate
)[
0
]
.
shape
[
0
]
num_steps
=
f_state
.
shape
[
0
]
//
batch_size
multi_step
=
num_steps
>
1
if
done
is
not
None
:
assert
switch_or_main
is
not
None
else
:
assert
not
multi_step
if
multi_step
:
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
rstate
,
f_state_r
=
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
if
self
.
film
:
actor
=
FiLMActor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
noam
=
self
.
noam
)
else
:
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
return
rstate
,
logits
,
value
,
valid
def
init_rnn_state
(
self
,
batch_size
):
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
return
(
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
)
elif
self
.
rnn_type
==
'gru'
:
return
np
.
zeros
((
batch_size
,
self
.
rnn_channels
))
else
:
return
None
\ No newline at end of file
ygoai/utils.py
View file @
04e61b91
...
@@ -41,7 +41,12 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
...
@@ -41,7 +41,12 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
raise
FileNotFoundError
(
f
"Token deck not found: {token_deck}"
)
raise
FileNotFoundError
(
f
"Token deck not found: {token_deck}"
)
decks
[
"_tokens"
]
=
str
(
token_deck
)
decks
[
"_tokens"
]
=
str
(
token_deck
)
if
'YGOPro'
in
env_id
:
if
'YGOPro'
in
env_id
:
from
ygoenv.ygopro
import
init_module
if
env_id
==
'YGOPro-v1'
:
from
ygoenv.ygopro
import
init_module
elif
env_id
==
'YGOPro-v0'
:
from
ygoenv.ygopro0
import
init_module
else
:
raise
ValueError
(
f
"Unknown YGOPro environment: {env_id}"
)
elif
'EDOPro'
in
env_id
:
elif
'EDOPro'
in
env_id
:
from
ygoenv.edopro
import
init_module
from
ygoenv.edopro
import
init_module
init_module
(
str
(
db_path
),
code_list_file
,
decks
)
init_module
(
str
(
db_path
),
code_list_file
,
decks
)
...
...
ygoenv/ygoenv/
ygopro
/BS_thread_pool.h
→
ygoenv/ygoenv/
core
/BS_thread_pool.h
View file @
04e61b91
File moved
ygoenv/ygoenv/entry.py
View file @
04e61b91
...
@@ -18,13 +18,16 @@ try:
...
@@ -18,13 +18,16 @@ try:
except
ImportError
:
except
ImportError
:
pass
pass
try
:
import
ygoenv.ygopro0.registration
# noqa: F401
except
ImportError
:
pass
try
:
try
:
import
ygoenv.edopro.registration
# noqa: F401
import
ygoenv.edopro.registration
# noqa: F401
except
ImportError
:
except
ImportError
:
pass
pass
try
:
try
:
import
ygoenv.dummy.registration
# noqa: F401
import
ygoenv.dummy.registration
# noqa: F401
except
ImportError
:
except
ImportError
:
...
...
ygoenv/ygoenv/ygopro/registration.py
View file @
04e61b91
from
ygoenv.registration
import
register
from
ygoenv.registration
import
register
register
(
register
(
task_id
=
"YGOPro-v
0
"
,
task_id
=
"YGOPro-v
1
"
,
import_path
=
"ygoenv.ygopro"
,
import_path
=
"ygoenv.ygopro"
,
spec_cls
=
"YGOProEnvSpec"
,
spec_cls
=
"YGOProEnvSpec"
,
dm_cls
=
"YGOProDMEnvPool"
,
dm_cls
=
"YGOProDMEnvPool"
,
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
04e61b91
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#include <ankerl/unordered_dense.h>
#include <ankerl/unordered_dense.h>
#include <unordered_set>
#include <unordered_set>
#include "BS_thread_pool.h"
#include "
ygoenv/core/
BS_thread_pool.h"
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h"
#include "ygoenv/core/env.h"
...
@@ -305,13 +305,85 @@ static std::string msg_to_string(int msg) {
...
@@ -305,13 +305,85 @@ static std::string msg_to_string(int msg) {
}
}
// system string
// system string
static
const
ankerl
::
unordered_dense
::
map
<
int
,
std
::
string
>
system_strings
=
{
static
const
std
::
map
<
int
,
std
::
string
>
system_strings
=
{
// announce type
{
1050
,
"Monster"
},
{
1051
,
"Spell"
},
{
1052
,
"Trap"
},
{
1054
,
"Normal"
},
{
1055
,
"Effect"
},
{
1056
,
"Fusion"
},
{
1057
,
"Ritual"
},
{
1058
,
"Trap Monsters"
},
{
1059
,
"Spirit"
},
{
1060
,
"Union"
},
{
1061
,
"Gemini"
},
{
1062
,
"Tuner"
},
{
1063
,
"Synchro"
},
{
1064
,
"Token"
},
{
1066
,
"Quick-Play"
},
{
1067
,
"Continuous"
},
{
1068
,
"Equip"
},
{
1069
,
"Field"
},
{
1070
,
"Counter"
},
{
1071
,
"Flip"
},
{
1072
,
"Toon"
},
{
1073
,
"Xyz"
},
{
1074
,
"Pendulum"
},
{
1075
,
"Special Summon"
},
{
1076
,
"Link"
},
{
1080
,
"(N/A)"
},
{
1081
,
"Extra Monster Zone"
},
// announce type end
// actions
{
1150
,
"Activate"
},
{
1151
,
"Normal Summon"
},
{
1152
,
"Special Summon"
},
{
1153
,
"Set"
},
{
1154
,
"Flip Summon"
},
{
1155
,
"To Defense"
},
{
1156
,
"To Attack"
},
{
1157
,
"Attack"
},
{
1158
,
"View"
},
{
1159
,
"S/T Set"
},
{
1160
,
"Put in Pendulum Zone"
},
{
1161
,
"Do Effect"
},
{
1162
,
"Reset Effect"
},
{
1163
,
"Pendulum Summon"
},
{
1164
,
"Synchro Summon"
},
{
1165
,
"Xyz Summon"
},
{
1166
,
"Link Summon"
},
{
1167
,
"Tribute Summon"
},
{
1168
,
"Ritual Summon"
},
{
1169
,
"Fusion Summon"
},
{
1190
,
"Add to hand"
},
{
1191
,
"Send to GY"
},
{
1192
,
"Banish"
},
{
1193
,
"Return to Deck"
},
// actions end
{
1
,
"Normal Summon"
},
{
30
,
"Replay rules apply. Continue this attack?"
},
{
30
,
"Replay rules apply. Continue this attack?"
},
{
31
,
"Attack directly with this monster?"
},
{
31
,
"Attack directly with this monster?"
},
{
80
,
"Start Step of the Battle Phase."
},
{
81
,
"During the End Phase."
},
{
90
,
"Conduct this Normal Summon without Tributing?"
},
{
91
,
"Use additional Summon?"
},
{
92
,
"Tribute your opponent's monster?"
},
{
93
,
"Continue selecting Materials?"
},
{
94
,
"Activate this card's effect now?"
},
{
95
,
"Use the effect of [%ls]?"
},
{
96
,
"Use the effect of [%ls] to avoid destruction?"
},
{
96
,
"Use the effect of [%ls] to avoid destruction?"
},
{
97
,
"Place [%ls] to a Spell & Trap Zone?"
},
{
98
,
"Tribute a monster(s) your opponent controls?"
},
{
200
,
"From [%ls], activate [%ls]?"
},
{
203
,
"Chain another card or effect?"
},
{
210
,
"Continue selecting?"
},
{
218
,
"Pay LP by Effect of [%ls], instead?"
},
{
219
,
"Detach Xyz material by Effect of [%ls], instead?"
},
{
220
,
"Remove Counter(s) by Effect of [%ls], instead?"
},
{
221
,
"On [%ls], Activate Trigger Effect of [%ls]?"
},
{
222
,
"Activate Trigger Effect?"
},
{
221
,
"On [%ls], Activate Trigger Effect of [%ls]?"
},
{
221
,
"On [%ls], Activate Trigger Effect of [%ls]?"
},
{
1190
,
"Add to hand"
},
{
1192
,
"Banish"
},
{
1621
,
"Attack Negated"
},
{
1621
,
"Attack Negated"
},
{
1622
,
"[%ls] Missed timing"
}
{
1622
,
"[%ls] Missed timing"
}
};
};
...
@@ -321,7 +393,9 @@ static std::string get_system_string(int desc) {
...
@@ -321,7 +393,9 @@ static std::string get_system_string(int desc) {
if
(
it
!=
system_strings
.
end
())
{
if
(
it
!=
system_strings
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
}
return
"system string "
+
std
::
to_string
(
desc
);
throw
std
::
runtime_error
(
fmt
::
format
(
"Cannot find system string: {}"
,
desc
));
// return "system string " + std::to_string(desc);
}
}
static
std
::
string
ltrim
(
std
::
string
s
)
{
static
std
::
string
ltrim
(
std
::
string
s
)
{
...
@@ -331,24 +405,6 @@ static std::string ltrim(std::string s) {
...
@@ -331,24 +405,6 @@ static std::string ltrim(std::string s) {
return
s
;
return
s
;
}
}
inline
std
::
vector
<
std
::
string
>
flag_to_usable_cardspecs
(
uint32_t
flag
,
bool
reverse
=
false
)
{
std
::
string
zone_names
[
4
]
=
{
"m"
,
"s"
,
"om"
,
"os"
};
std
::
vector
<
std
::
string
>
specs
;
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
uint32_t
value
=
(
flag
>>
(
j
*
8
))
&
0xff
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
bool
avail
=
(
value
&
(
1
<<
i
))
==
0
;
if
(
reverse
)
{
avail
=
!
avail
;
}
if
(
avail
)
{
specs
.
push_back
(
zone_names
[
j
]
+
std
::
to_string
(
i
+
1
));
}
}
}
return
specs
;
}
inline
std
::
string
ls_to_spec
(
uint8_t
loc
,
uint8_t
seq
,
uint8_t
pos
)
{
inline
std
::
string
ls_to_spec
(
uint8_t
loc
,
uint8_t
seq
,
uint8_t
pos
)
{
std
::
string
spec
;
std
::
string
spec
;
...
@@ -402,7 +458,8 @@ spec_to_ls(const std::string spec) {
...
@@ -402,7 +458,8 @@ spec_to_ls(const std::string spec) {
loc
=
LOCATION_DECK
;
loc
=
LOCATION_DECK
;
offset
=
0
;
offset
=
0
;
}
else
{
}
else
{
throw
std
::
runtime_error
(
"Invalid location"
);
std
::
string
s
=
fmt
::
format
(
"Invalid spec {}"
,
spec
);
throw
std
::
runtime_error
(
s
);
}
}
int
end
=
offset
;
int
end
=
offset
;
while
(
end
<
spec
.
size
()
&&
std
::
isdigit
(
spec
[
end
]))
{
while
(
end
<
spec
.
size
()
&&
std
::
isdigit
(
spec
[
end
]))
{
...
@@ -415,33 +472,19 @@ spec_to_ls(const std::string spec) {
...
@@ -415,33 +472,19 @@ spec_to_ls(const std::string spec) {
return
{
loc
,
seq
,
pos
};
return
{
loc
,
seq
,
pos
};
}
}
inline
uint32_t
ls_to_spec_code
(
uint8_t
loc
,
uint8_t
seq
,
uint8_t
pos
,
bool
opponent
)
{
uint32_t
c
=
opponent
?
1
:
0
;
c
|=
(
loc
<<
8
);
c
|=
(
seq
<<
16
);
c
|=
(
pos
<<
24
);
return
c
;
}
inline
uint32_t
spec_to_code
(
const
std
::
string
&
spec
)
{
inline
std
::
tuple
<
uint8_t
,
uint8_t
,
uint8_t
,
uint8_t
>
spec_to_ls
(
uint8_t
player
,
const
std
::
string
spec
)
{
uint8_t
controller
=
player
;
int
offset
=
0
;
int
offset
=
0
;
bool
opponent
=
false
;
if
(
spec
[
0
]
==
'o'
)
{
if
(
spec
[
0
]
==
'o'
)
{
opponent
=
true
;
controller
=
1
-
player
;
offset
++
;
offset
++
;
}
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
.
substr
(
offset
));
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
.
substr
(
offset
));
return
ls_to_spec_code
(
loc
,
seq
,
pos
,
opponent
)
;
return
{
controller
,
loc
,
seq
,
pos
}
;
}
}
inline
std
::
string
code_to_spec
(
uint32_t
spec_code
)
{
uint8_t
loc
=
(
spec_code
>>
8
)
&
0xff
;
uint8_t
seq
=
(
spec_code
>>
16
)
&
0xff
;
uint8_t
pos
=
(
spec_code
>>
24
)
&
0xff
;
bool
opponent
=
(
spec_code
&
0xff
)
==
1
;
return
ls_to_spec
(
loc
,
seq
,
pos
,
opponent
);
}
static
std
::
tuple
<
std
::
vector
<
uint32
>
,
std
::
vector
<
uint32
>
,
std
::
vector
<
uint32
>>
read_decks
(
const
std
::
string
&
fp
)
{
static
std
::
tuple
<
std
::
vector
<
uint32
>
,
std
::
vector
<
uint32
>
,
std
::
vector
<
uint32
>>
read_decks
(
const
std
::
string
&
fp
)
{
std
::
ifstream
file
(
fp
);
std
::
ifstream
file
(
fp
);
...
@@ -567,6 +610,11 @@ inline std::string name(decltype(x_map)::key_type x) { \
...
@@ -567,6 +610,11 @@ inline std::string name(decltype(x_map)::key_type x) { \
return "unknown"; \
return "unknown"; \
}
}
static
const
ankerl
::
unordered_dense
::
map
<
int
,
uint8_t
>
system_string2id
=
make_ids
(
system_strings
,
16
);
DEFINE_X_TO_ID_FUN
(
system_string_to_id
,
system_string2id
)
static
const
std
::
map
<
uint8_t
,
std
::
string
>
location2str
=
{
static
const
std
::
map
<
uint8_t
,
std
::
string
>
location2str
=
{
{
LOCATION_DECK
,
"Deck"
},
{
LOCATION_DECK
,
"Deck"
},
{
LOCATION_HAND
,
"Hand"
},
{
LOCATION_HAND
,
"Hand"
},
...
@@ -722,29 +770,152 @@ static const ankerl::unordered_dense::map<int, uint8_t> msg2id =
...
@@ -722,29 +770,152 @@ static const ankerl::unordered_dense::map<int, uint8_t> msg2id =
DEFINE_X_TO_ID_FUN
(
msg_to_id
,
msg2id
)
DEFINE_X_TO_ID_FUN
(
msg_to_id
,
msg2id
)
static
const
ankerl
::
unordered_dense
::
map
<
char
,
uint8_t
>
cmd_act2id
=
enum
class
ActionAct
{
make_ids
({
't'
,
'r'
,
'c'
,
's'
,
'm'
,
'a'
,
'v'
},
1
);
None
,
DEFINE_X_TO_ID_FUN
(
cmd_act_to_id
,
cmd_act2id
)
Set
,
Repo
,
SpSummon
,
Summon
,
MSet
,
Attack
,
DirectAttack
,
Activate
,
Cancel
,
};
inline
std
::
string
action_act_to_string
(
ActionAct
act
)
{
switch
(
act
)
{
case
ActionAct
:
:
None
:
return
"None"
;
case
ActionAct
:
:
Set
:
return
"Set"
;
case
ActionAct
:
:
Repo
:
return
"Repo"
;
case
ActionAct
:
:
SpSummon
:
return
"SpSummon"
;
case
ActionAct
:
:
Summon
:
return
"Summon"
;
case
ActionAct
:
:
MSet
:
return
"MSet"
;
case
ActionAct
:
:
Attack
:
return
"Attack"
;
case
ActionAct
:
:
DirectAttack
:
return
"DirectAttack"
;
case
ActionAct
:
:
Activate
:
return
"Activate"
;
case
ActionAct
:
:
Cancel
:
return
"Cancel"
;
default:
return
"Unknown"
;
}
}
static
const
ankerl
::
unordered_dense
::
map
<
char
,
uint8_t
>
cmd_phase2id
=
enum
class
ActionPhase
{
make_ids
(
std
::
vector
<
char
>
({
'b'
,
'm'
,
'e'
}),
1
);
None
,
DEFINE_X_TO_ID_FUN
(
cmd_phase_to_id
,
cmd_phase2id
)
Battle
,
Main2
,
End
,
};
inline
std
::
string
action_phase_to_string
(
ActionPhase
phase
)
{
switch
(
phase
)
{
case
ActionPhase
:
:
None
:
return
"None"
;
case
ActionPhase
:
:
Battle
:
return
"Battle"
;
case
ActionPhase
:
:
Main2
:
return
"Main2"
;
case
ActionPhase
:
:
End
:
return
"End"
;
default:
return
"Unknown"
;
}
}
static
const
ankerl
::
unordered_dense
::
map
<
char
,
uint8_t
>
cmd_yesno2id
=
enum
class
ActionPlace
{
make_ids
(
std
::
vector
<
char
>
({
'y'
,
'n'
}),
1
);
None
,
DEFINE_X_TO_ID_FUN
(
cmd_yesno_to_id
,
cmd_yesno2id
)
MZone1
,
MZone2
,
MZone3
,
MZone4
,
MZone5
,
MZone6
,
MZone7
,
SZone1
,
SZone2
,
SZone3
,
SZone4
,
SZone5
,
SZone6
,
SZone7
,
SZone8
,
OpMZone1
,
OpMZone2
,
OpMZone3
,
OpMZone4
,
OpMZone5
,
OpMZone6
,
OpMZone7
,
OpSZone1
,
OpSZone2
,
OpSZone3
,
OpSZone4
,
OpSZone5
,
OpSZone6
,
OpSZone7
,
OpSZone8
,
};
static
const
ankerl
::
unordered_dense
::
map
<
std
::
string
,
uint8_t
>
cmd_place2id
=
inline
std
::
vector
<
ActionPlace
>
flag_to_usable_places
(
make_ids
(
std
::
vector
<
std
::
string
>
(
uint32_t
flag
,
bool
reverse
=
false
)
{
{
"m1"
,
"m2"
,
"m3"
,
"m4"
,
"m5"
,
"m6"
,
"m7"
,
"s1"
,
std
::
vector
<
ActionPlace
>
places
;
"s2"
,
"s3"
,
"s4"
,
"s5"
,
"s6"
,
"s7"
,
"s8"
,
"om1"
,
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
"om2"
,
"om3"
,
"om4"
,
"om5"
,
"om6"
,
"om7"
,
"os1"
,
"os2"
,
uint32_t
value
=
(
flag
>>
(
j
*
8
))
&
0xff
;
"os3"
,
"os4"
,
"os5"
,
"os6"
,
"os7"
,
"os8"
}),
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
1
);
bool
avail
=
(
value
&
(
1
<<
i
))
==
0
;
DEFINE_X_TO_ID_FUN
(
cmd_place_to_id
,
cmd_place2id
)
if
(
reverse
)
{
avail
=
!
avail
;
}
if
(
avail
)
{
ActionPlace
place
;
if
(
j
==
0
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
MZone1
));
}
else
if
(
j
==
1
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
SZone1
));
}
else
if
(
j
==
2
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
));
}
else
if
(
j
==
3
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
));
}
places
.
push_back
(
place
);
}
}
}
return
places
;
}
inline
std
::
string
action_place_to_string
(
ActionPlace
place
)
{
int
i
=
static_cast
<
int
>
(
place
);
if
(
i
==
0
)
{
return
"None"
;
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
MZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
MZone7
))
{
return
fmt
::
format
(
"m{}"
,
i
-
static_cast
<
int
>
(
ActionPlace
::
MZone1
)
+
1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
SZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
SZone8
))
{
return
fmt
::
format
(
"s{}"
,
i
-
static_cast
<
int
>
(
ActionPlace
::
SZone1
)
+
1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
OpMZone7
))
{
return
fmt
::
format
(
"om{}"
,
i
-
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
)
+
1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
OpSZone8
))
{
return
fmt
::
format
(
"os{}"
,
i
-
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
)
+
1
);
}
else
{
return
"Unknown"
;
}
}
inline
std
::
pair
<
uint8_t
,
uint8_t
>
float_transform
(
int
x
)
{
inline
std
::
pair
<
uint8_t
,
uint8_t
>
float_transform
(
int
x
)
{
...
@@ -807,6 +978,89 @@ using PlayerId = uint8_t;
...
@@ -807,6 +978,89 @@ using PlayerId = uint8_t;
using
CardCode
=
uint32_t
;
using
CardCode
=
uint32_t
;
using
CardId
=
uint16_t
;
using
CardId
=
uint16_t
;
const
int
DESCRIPTION_LIMIT
=
10000
;
const
int
CARD_EFFECT_OFFSET
=
10010
;
class
LegalAction
{
public:
std
::
string
spec_
=
""
;
ActionAct
act_
=
ActionAct
::
None
;
ActionPhase
phase_
=
ActionPhase
::
None
;
bool
finish_
=
false
;
uint8_t
position_
=
0
;
int
effect_
=
-
1
;
uint8_t
number_
=
0
;
ActionPlace
place_
=
ActionPlace
::
None
;
uint8_t
attribute_
=
0
;
int
spec_index_
=
0
;
CardId
cid_
=
0
;
int
msg_
=
0
;
static
LegalAction
from_spec
(
const
std
::
string
&
spec
)
{
LegalAction
la
;
la
.
spec_
=
spec
;
return
la
;
}
static
LegalAction
act_spec
(
ActionAct
act
,
const
std
::
string
&
spec
)
{
LegalAction
la
;
la
.
act_
=
act
;
la
.
spec_
=
spec
;
return
la
;
}
static
LegalAction
finish
()
{
LegalAction
la
;
la
.
finish_
=
true
;
return
la
;
}
static
LegalAction
cancel
()
{
LegalAction
la
;
la
.
act_
=
ActionAct
::
Cancel
;
return
la
;
}
static
LegalAction
activate_spec
(
int
effect_idx
,
const
std
::
string
&
spec
)
{
LegalAction
la
;
la
.
act_
=
ActionAct
::
Activate
;
la
.
effect_
=
effect_idx
;
la
.
spec_
=
spec
;
return
la
;
}
static
LegalAction
phase
(
ActionPhase
phase
)
{
LegalAction
la
;
la
.
phase_
=
phase
;
return
la
;
}
static
LegalAction
number
(
uint8_t
number
)
{
LegalAction
la
;
la
.
number_
=
number
;
return
la
;
}
static
LegalAction
place
(
ActionPlace
place
)
{
LegalAction
la
;
la
.
place_
=
place
;
return
la
;
}
static
LegalAction
attribute
(
int
attribute
)
{
LegalAction
la
;
la
.
attribute_
=
attribute
;
return
la
;
}
};
class
SpecInfo
{
public:
uint16_t
index
;
CardId
cid
;
};
class
Card
{
class
Card
{
friend
class
YGOProEnv
;
friend
class
YGOProEnv
;
...
@@ -874,42 +1128,23 @@ public:
...
@@ -874,42 +1128,23 @@ public:
return
get_spec
(
player
!=
controler_
);
return
get_spec
(
player
!=
controler_
);
}
}
uint32_t
get_spec_code
(
PlayerId
player
)
const
{
return
ls_to_spec_code
(
location_
,
sequence_
,
position_
,
player
!=
controler_
);
}
std
::
string
get_position
()
const
{
return
position_to_string
(
position_
);
}
std
::
string
get_position
()
const
{
return
position_to_string
(
position_
);
}
std
::
string
get_effect_description
(
uint32_t
desc
,
std
::
string
get_effect_description
(
CardCode
code
,
int
effect_idx
)
const
{
bool
existing
=
false
)
const
{
if
(
code
==
0
)
{
std
::
string
s
;
return
get_system_string
(
effect_idx
);
bool
e
=
false
;
auto
code
=
code_
;
if
(
desc
>
10000
)
{
code
=
desc
>>
4
;
}
}
uint32_t
offset
=
desc
-
code_
*
16
;
if
(
effect_idx
==
0
)
{
bool
in_range
=
(
offset
>=
0
)
&&
(
offset
<
strings_
.
size
());
return
"default"
;
std
::
string
str
=
""
;
if
(
in_range
)
{
str
=
ltrim
(
strings_
[
offset
]);
}
}
if
(
in_range
||
desc
==
0
)
{
effect_idx
-=
CARD_EFFECT_OFFSET
;
if
((
desc
==
0
)
||
str
.
empty
())
{
if
(
effect_idx
<
0
)
{
s
=
"Activate "
+
name_
+
"."
;
throw
std
::
runtime_error
(
}
else
{
fmt
::
format
(
"Invalid effect index: {}"
,
effect_idx
));
s
=
name_
+
" ("
+
str
+
")"
;
e
=
true
;
}
}
else
{
s
=
get_system_string
(
desc
);
if
(
!
s
.
empty
())
{
e
=
true
;
}
}
}
if
(
existing
&&
!
e
)
{
auto
s
=
strings_
[
effect_idx
];
s
=
""
;
if
(
s
.
empty
())
{
return
"effect "
+
std
::
to_string
(
effect_idx
);
}
}
return
s
;
return
s
;
}
}
...
@@ -1222,7 +1457,7 @@ public:
...
@@ -1222,7 +1457,7 @@ public:
const
int
&
init_lp
()
const
{
return
init_lp_
;
}
const
int
&
init_lp
()
const
{
return
init_lp_
;
}
virtual
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
=
0
;
virtual
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
=
0
;
};
};
class
GreedyAI
:
public
Player
{
class
GreedyAI
:
public
Player
{
...
@@ -1232,7 +1467,7 @@ public:
...
@@ -1232,7 +1467,7 @@ public:
bool
verbose
=
false
)
bool
verbose
=
false
)
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
)
{}
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
)
{}
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
override
{
return
0
;
}
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
override
{
return
0
;
}
};
};
class
RandomAI
:
public
Player
{
class
RandomAI
:
public
Player
{
...
@@ -1246,8 +1481,8 @@ public:
...
@@ -1246,8 +1481,8 @@ public:
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
),
gen_
(
seed
),
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
),
gen_
(
seed
),
dist_
(
0
,
max_options
-
1
)
{}
dist_
(
0
,
max_options
-
1
)
{}
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
override
{
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
override
{
return
dist_
(
gen_
)
%
op
tions
.
size
();
return
dist_
(
gen_
)
%
ac
tions
.
size
();
}
}
};
};
...
@@ -1258,17 +1493,17 @@ public:
...
@@ -1258,17 +1493,17 @@ public:
bool
verbose
=
false
)
bool
verbose
=
false
)
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
)
{}
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
)
{}
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
override
{
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
override
{
while
(
true
)
{
while
(
true
)
{
std
::
string
input
=
getline
();
std
::
string
input
=
getline
();
if
(
input
==
"quit"
)
{
if
(
input
==
"quit"
)
{
exit
(
0
);
exit
(
0
);
}
}
auto
it
=
std
::
find
(
options
.
begin
(),
options
.
end
(),
input
)
;
int
idx
=
std
::
stoi
(
input
)
-
1
;
if
(
i
t
!=
options
.
end
())
{
if
(
i
dx
>=
0
&&
idx
<
actions
.
size
())
{
return
std
::
distance
(
options
.
begin
(),
it
)
;
return
idx
;
}
else
{
}
else
{
fmt
::
println
(
"{} Choose from {}
"
,
duel_player_
,
options
);
fmt
::
println
(
"{} Choose from {}
actions"
,
duel_player_
,
actions
.
size
()
);
}
}
}
}
}
}
...
@@ -1286,7 +1521,7 @@ public:
...
@@ -1286,7 +1521,7 @@ public:
}
}
template
<
typename
Config
>
template
<
typename
Config
>
static
decltype
(
auto
)
StateSpec
(
const
Config
&
conf
)
{
static
decltype
(
auto
)
StateSpec
(
const
Config
&
conf
)
{
int
n_action_feats
=
1
3
;
int
n_action_feats
=
1
2
;
return
MakeDict
(
return
MakeDict
(
"obs:cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
41
})),
"obs:cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
41
})),
"obs:global_"
_
.
Bind
(
Spec
<
uint8_t
>
({
23
})),
"obs:global_"
_
.
Bind
(
Spec
<
uint8_t
>
({
23
})),
...
@@ -1393,7 +1628,7 @@ protected:
...
@@ -1393,7 +1628,7 @@ protected:
int
turn_count_
;
int
turn_count_
;
int
msg_
;
int
msg_
;
std
::
vector
<
std
::
string
>
op
tions_
;
std
::
vector
<
LegalAction
>
legal_ac
tions_
;
PlayerId
to_play_
;
PlayerId
to_play_
;
std
::
function
<
void
(
int
)
>
callback_
;
std
::
function
<
void
(
int
)
>
callback_
;
...
@@ -1423,9 +1658,10 @@ protected:
...
@@ -1423,9 +1658,10 @@ protected:
const
int
n_history_actions_
;
const
int
n_history_actions_
;
// circular buffer for history actions
// circular buffer for history actions
TArray
<
uint8_t
>
history_actions_
;
TArray
<
uint8_t
>
history_actions_1_
;
int
ha_p_
=
0
;
TArray
<
uint8_t
>
history_actions_2_
;
std
::
vector
<
CardId
>
h_card_ids_
;
int
ha_p_1_
=
0
;
int
ha_p_2_
=
0
;
std
::
unordered_set
<
std
::
string
>
revealed_
;
std
::
unordered_set
<
std
::
string
>
revealed_
;
...
@@ -1487,8 +1723,9 @@ public:
...
@@ -1487,8 +1723,9 @@ public:
int
max_options
=
spec
.
config
[
"max_options"
_
];
int
max_options
=
spec
.
config
[
"max_options"
_
];
int
n_action_feats
=
spec
.
state_spec
[
"obs:actions_"
_
].
shape
[
1
];
int
n_action_feats
=
spec
.
state_spec
[
"obs:actions_"
_
].
shape
[
1
];
h_card_ids_
.
resize
(
max_options
);
history_actions_1_
=
TArray
<
uint8_t
>
(
Array
(
history_actions_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
+
2
})));
history_actions_2_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
+
2
})));
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
+
2
})));
}
}
...
@@ -1560,8 +1797,10 @@ public:
...
@@ -1560,8 +1797,10 @@ public:
turn_count_
=
0
;
turn_count_
=
0
;
ms_idx_
=
-
1
;
ms_idx_
=
-
1
;
history_actions_
.
Zero
();
history_actions_1_
.
Zero
();
ha_p_
=
0
;
history_actions_2_
.
Zero
();
ha_p_1_
=
0
;
ha_p_2_
=
0
;
clock_t
_start
=
clock
();
clock_t
_start
=
clock
();
...
@@ -1720,7 +1959,7 @@ public:
...
@@ -1720,7 +1959,7 @@ public:
if
(
ms_mode_
==
0
)
{
if
(
ms_mode_
==
0
)
{
for
(
int
j
=
0
;
j
<
ms_specs_
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
ms_specs_
.
size
();
++
j
)
{
const
auto
&
spec
=
ms_specs_
[
j
];
const
auto
&
spec
=
ms_specs_
[
j
];
options_
.
push_back
(
spec
);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
spec
)
);
}
}
}
else
{
}
else
{
ms_combs_
=
combs
;
ms_combs_
=
combs
;
...
@@ -1729,22 +1968,23 @@ public:
...
@@ -1729,22 +1968,23 @@ public:
}
}
void
handle_multi_select
()
{
void
handle_multi_select
()
{
options_
=
{}
;
legal_actions_
.
clear
()
;
if
(
ms_mode_
==
0
)
{
if
(
ms_mode_
==
0
)
{
for
(
int
j
=
0
;
j
<
ms_specs_
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
ms_specs_
.
size
();
++
j
)
{
if
(
ms_spec2idx_
.
find
(
ms_specs_
[
j
])
!=
ms_spec2idx_
.
end
())
{
if
(
ms_spec2idx_
.
find
(
ms_specs_
[
j
])
!=
ms_spec2idx_
.
end
())
{
options_
.
push_back
(
ms_specs_
[
j
]);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
ms_specs_
[
j
]));
}
}
}
}
if
(
ms_idx_
==
ms_max_
-
1
)
{
if
(
ms_idx_
==
ms_max_
-
1
)
{
if
(
ms_idx_
>=
ms_min_
)
{
if
(
ms_idx_
>=
ms_min_
)
{
options_
.
push_back
(
"f"
);
legal_actions_
.
push_back
(
LegalAction
::
finish
()
);
}
}
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
_callback_multi_select
(
idx
,
true
);
_callback_multi_select
(
idx
,
true
);
};
};
}
else
if
(
ms_idx_
>=
ms_min_
)
{
}
else
if
(
ms_idx_
>=
ms_min_
)
{
options_
.
push_back
(
"f"
);
legal_actions_
.
push_back
(
LegalAction
::
finish
()
);
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
_callback_multi_select
(
idx
,
false
);
_callback_multi_select
(
idx
,
false
);
};
};
...
@@ -1766,7 +2006,7 @@ public:
...
@@ -1766,7 +2006,7 @@ public:
if
(
it
!=
ms_spec2idx_
.
end
())
{
if
(
it
!=
ms_spec2idx_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
}
// TODO: find the root cause
// TODO
(2)
: find the root cause
// print ms_spec2idx
// print ms_spec2idx
show_deck
(
0
);
show_deck
(
0
);
show_deck
(
1
);
show_deck
(
1
);
...
@@ -1783,11 +2023,15 @@ public:
...
@@ -1783,11 +2023,15 @@ public:
}
}
void
_callback_multi_select_2
(
int
idx
)
{
void
_callback_multi_select_2
(
int
idx
)
{
const
auto
&
option
=
op
tions_
[
idx
];
const
auto
&
action
=
legal_ac
tions_
[
idx
];
idx
=
get_ms_spec_idx
(
option
);
idx
=
get_ms_spec_idx
(
action
.
spec_
);
if
(
idx
==
-
1
)
{
if
(
idx
==
-
1
)
{
// TODO: find the root cause
// TODO(2): find the root cause
fmt
::
println
(
"options: {}, idx: {}, option: {}"
,
options_
,
idx
,
option
);
std
::
vector
<
std
::
string
>
specs
;
for
(
const
auto
&
la
:
legal_actions_
)
{
specs
.
push_back
(
la
.
spec_
);
}
fmt
::
println
(
"specs: {}, idx: {}, spec: {}"
,
specs
,
idx
,
action
.
spec_
);
throw
std
::
runtime_error
(
"Spec not found"
);
throw
std
::
runtime_error
(
"Spec not found"
);
}
}
ms_r_idxs_
.
push_back
(
idx
);
ms_r_idxs_
.
push_back
(
idx
);
...
@@ -1814,7 +2058,7 @@ public:
...
@@ -1814,7 +2058,7 @@ public:
}
}
for
(
auto
&
i
:
comb
)
{
for
(
auto
&
i
:
comb
)
{
const
auto
&
spec
=
ms_specs_
[
i
];
const
auto
&
spec
=
ms_specs_
[
i
];
options_
.
push_back
(
spec
);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
spec
)
);
}
}
}
}
...
@@ -1831,17 +2075,21 @@ public:
...
@@ -1831,17 +2075,21 @@ public:
}
}
void
_callback_multi_select
(
int
idx
,
bool
finish
)
{
void
_callback_multi_select
(
int
idx
,
bool
finish
)
{
const
auto
&
option
=
op
tions_
[
idx
];
const
auto
&
action
=
legal_ac
tions_
[
idx
];
// fmt::println("Select card: {}, finish: {}", option, finish);
// fmt::println("Select card: {}, finish: {}", option, finish);
if
(
option
==
"f"
)
{
if
(
action
.
finish_
)
{
finish
=
true
;
finish
=
true
;
}
else
{
}
else
{
idx
=
get_ms_spec_idx
(
option
);
idx
=
get_ms_spec_idx
(
action
.
spec_
);
if
(
idx
!=
-
1
)
{
if
(
idx
!=
-
1
)
{
ms_r_idxs_
.
push_back
(
idx
);
ms_r_idxs_
.
push_back
(
idx
);
}
else
{
}
else
{
// TODO: find the root cause
// TODO(2): find the root cause
fmt
::
println
(
"options: {}, idx: {}, option: {}"
,
options_
,
idx
,
option
);
std
::
vector
<
std
::
string
>
specs
;
for
(
const
auto
&
la
:
legal_actions_
)
{
specs
.
push_back
(
la
.
spec_
);
}
fmt
::
println
(
"specs: {}, idx: {}, spec: {}"
,
specs
,
idx
,
action
.
spec_
);
ms_idx_
=
-
1
;
ms_idx_
=
-
1
;
resp_buf_
[
0
]
=
ms_min_
;
resp_buf_
[
0
]
=
ms_min_
;
for
(
int
i
=
0
;
i
<
ms_min_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ms_min_
;
++
i
)
{
...
@@ -1860,27 +2108,27 @@ public:
...
@@ -1860,27 +2108,27 @@ public:
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
}
else
{
}
else
{
ms_idx_
++
;
ms_idx_
++
;
ms_spec2idx_
.
erase
(
option
);
ms_spec2idx_
.
erase
(
action
.
spec_
);
}
}
}
}
void
update_h_card_ids
(
PlayerId
player
,
int
idx
)
{
void
update_history_actions
(
PlayerId
player
,
const
LegalAction
&
action
)
{
h_card_ids_
[
idx
]
=
parse_card_id
(
options_
[
idx
],
player
);
if
(
action
.
act_
==
ActionAct
::
Cancel
)
{
}
void
update_history_actions
(
PlayerId
player
,
int
idx
)
{
if
((
msg_
==
MSG_SELECT_CHAIN
)
&
(
options_
[
idx
][
0
]
==
'c'
))
{
return
;
return
;
}
}
ha_p_
--
;
auto
&
ha_p
=
player
==
0
?
ha_p_1_
:
ha_p_2_
;
if
(
ha_p_
<
0
)
{
auto
&
history_actions
=
player
==
0
?
history_actions_1_
:
history_actions_2_
;
ha_p_
=
n_history_actions_
-
1
;
ha_p
--
;
if
(
ha_p
<
0
)
{
ha_p
=
n_history_actions_
-
1
;
}
}
history_actions_
[
ha_p_
].
Zero
();
history_actions
[
ha_p
].
Zero
();
_set_obs_action
(
history_actions_
,
ha_p_
,
msg_
,
options_
[
idx
],
{},
_set_obs_action
(
history_actions
,
ha_p
,
action
);
h_card_ids_
[
idx
]);
// Spec index not available in history actions
history_actions_
[
ha_p_
](
13
)
=
static_cast
<
uint8_t
>
(
player
);
history_actions
[
ha_p
](
0
)
=
0
;
history_actions_
[
ha_p_
](
14
)
=
static_cast
<
uint8_t
>
(
turn_count_
);
// history_actions[ha_p](12) = static_cast<uint8_t>(player);
history_actions
[
ha_p
](
12
)
=
static_cast
<
uint8_t
>
(
turn_count_
);
history_actions
[
ha_p
](
13
)
=
static_cast
<
uint8_t
>
(
phase_to_id
(
current_phase_
));
}
}
void
show_deck
(
const
std
::
vector
<
CardCode
>
&
deck
,
const
std
::
string
&
prefix
)
const
{
void
show_deck
(
const
std
::
vector
<
CardCode
>
&
deck
,
const
std
::
string
&
prefix
)
const
{
...
@@ -1910,18 +2158,18 @@ public:
...
@@ -1910,18 +2158,18 @@ public:
}
}
void
show_history_actions
(
PlayerId
player
)
const
{
void
show_history_actions
(
PlayerId
player
)
const
{
const
auto
&
ha
=
history_actions
_
;
const
auto
&
ha
=
player
==
0
?
history_actions_1_
:
history_actions_2
_
;
// print card ids of history actions
// print card ids of history actions
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
fmt
::
print
(
"history {}
\n
"
,
i
);
fmt
::
print
(
"history {}
\n
"
,
i
);
uint8_t
msg_id
=
uint8_t
(
ha
(
i
,
2
));
uint8_t
msg_id
=
uint8_t
(
ha
(
i
,
3
));
int
msg
=
_msgs
[
msg_id
-
1
];
int
msg
=
_msgs
[
msg_id
-
1
];
fmt
::
print
(
"msg: {},"
,
msg_to_string
(
msg
));
fmt
::
print
(
"msg: {},"
,
msg_to_string
(
msg
));
uint8_t
v1
=
ha
(
i
,
0
);
uint8_t
v1
=
ha
(
i
,
1
);
uint8_t
v2
=
ha
(
i
,
1
);
uint8_t
v2
=
ha
(
i
,
2
);
CardId
card_id
=
(
static_cast
<
CardId
>
(
v1
)
<<
8
)
+
static_cast
<
CardId
>
(
v2
);
CardId
card_id
=
(
static_cast
<
CardId
>
(
v1
)
<<
8
)
+
static_cast
<
CardId
>
(
v2
);
fmt
::
print
(
" {};"
,
card_id
);
fmt
::
print
(
" {};"
,
card_id
);
for
(
int
j
=
3
;
j
<
ha
.
Shape
()[
1
];
j
++
)
{
for
(
int
j
=
4
;
j
<
ha
.
Shape
()[
1
];
j
++
)
{
fmt
::
print
(
" {}"
,
uint8_t
(
ha
(
i
,
j
)));
fmt
::
print
(
" {}"
,
uint8_t
(
ha
(
i
,
j
)));
}
}
fmt
::
print
(
"
\n
"
);
fmt
::
print
(
"
\n
"
);
...
@@ -1933,7 +2181,7 @@ public:
...
@@ -1933,7 +2181,7 @@ public:
int
idx
=
action
[
"action"
_
];
int
idx
=
action
[
"action"
_
];
callback_
(
idx
);
callback_
(
idx
);
update_history_actions
(
to_play_
,
idx
);
update_history_actions
(
to_play_
,
legal_actions_
[
idx
]
);
PlayerId
player
=
to_play_
;
PlayerId
player
=
to_play_
;
...
@@ -2012,10 +2260,10 @@ public:
...
@@ -2012,10 +2260,10 @@ public:
}
}
private:
private:
using
SpecIn
dex
=
ankerl
::
unordered_dense
::
map
<
std
::
string
,
uint16_t
>
;
using
SpecIn
fos
=
ankerl
::
unordered_dense
::
map
<
std
::
string
,
SpecInfo
>
;
std
::
tuple
<
SpecIn
dex
,
std
::
vector
<
int
>>
_set_obs_cards
(
TArray
<
uint8_t
>
&
f_cards
,
PlayerId
to_play
)
{
std
::
tuple
<
SpecIn
fos
,
std
::
vector
<
int
>>
_set_obs_cards
(
TArray
<
uint8_t
>
&
f_cards
,
PlayerId
to_play
)
{
SpecIn
dex
spec2index
;
SpecIn
fos
spec_infos
;
std
::
vector
<
int
>
loc_n_cards
;
std
::
vector
<
int
>
loc_n_cards
;
int
offset
=
0
;
int
offset
=
0
;
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
...
@@ -2054,18 +2302,23 @@ private:
...
@@ -2054,18 +2302,23 @@ private:
hide
=
false
;
hide
=
false
;
}
}
}
}
CardId
card_id
=
0
;
if
(
!
hide
)
{
card_id
=
c_get_card_id
(
c
.
code_
);
}
_set_obs_card_
(
f_cards
,
offset
,
c
,
hide
);
_set_obs_card_
(
f_cards
,
offset
,
c
,
hide
);
offset
++
;
offset
++
;
spec2index
[
spec
]
=
static_cast
<
uint16_t
>
(
offset
);
spec_infos
[
spec
]
=
{
static_cast
<
uint16_t
>
(
offset
),
card_id
};
}
}
}
}
}
}
}
}
return
{
spec
2index
,
loc_n_cards
};
return
{
spec
_infos
,
loc_n_cards
};
}
}
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
bool
hide
)
{
bool
hide
,
CardId
card_id
=
0
)
{
// check offset exceeds max_cards
// check offset exceeds max_cards
uint8_t
location
=
c
.
location_
;
uint8_t
location
=
c
.
location_
;
bool
overlay
=
location
&
LOCATION_OVERLAY
;
bool
overlay
=
location
&
LOCATION_OVERLAY
;
...
@@ -2077,7 +2330,6 @@ private:
...
@@ -2077,7 +2330,6 @@ private:
}
}
if
(
!
hide
)
{
if
(
!
hide
)
{
auto
card_id
=
c_get_card_id
(
c
.
code_
);
f_cards
(
offset
,
0
)
=
static_cast
<
uint8_t
>
(
card_id
>>
8
);
f_cards
(
offset
,
0
)
=
static_cast
<
uint8_t
>
(
card_id
>>
8
);
f_cards
(
offset
,
1
)
=
static_cast
<
uint8_t
>
(
card_id
&
0xff
);
f_cards
(
offset
,
1
)
=
static_cast
<
uint8_t
>
(
card_id
&
0xff
);
}
}
...
@@ -2148,17 +2400,10 @@ private:
...
@@ -2148,17 +2400,10 @@ private:
}
}
}
}
void
_set_obs_action_spec
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
const
SpecInfo
&
find_spec_info
(
SpecInfos
&
spec_infos
,
const
std
::
string
&
spec
)
{
const
std
::
string
&
spec
,
auto
it
=
spec_infos
.
find
(
spec
);
const
SpecIndex
&
spec2index
,
if
(
it
==
spec_infos
.
end
())
{
CardId
card_id
=
0
)
{
// TODO(2): find the root cause
uint16_t
idx
;
if
(
spec2index
.
empty
())
{
idx
=
card_id
;
}
else
{
auto
it
=
spec2index
.
find
(
spec
);
if
(
it
==
spec2index
.
end
())
{
// TODO: find the root cause
// print spec2index
// print spec2index
show_deck
(
0
);
show_deck
(
0
);
show_deck
(
1
);
show_deck
(
1
);
...
@@ -2166,135 +2411,111 @@ private:
...
@@ -2166,135 +2411,111 @@ private:
show_turn
();
show_turn
();
fmt
::
println
(
"MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}"
,
ms_idx_
,
ms_mode_
,
ms_min_
,
ms_max_
,
ms_must_
,
ms_specs_
,
ms_combs_
);
fmt
::
println
(
"MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}"
,
ms_idx_
,
ms_mode_
,
ms_min_
,
ms_max_
,
ms_must_
,
ms_specs_
,
ms_combs_
);
fmt
::
println
(
"Spec: {}, Spec2index:"
,
spec
);
fmt
::
println
(
"Spec: {}, Spec2index:"
,
spec
);
for
(
auto
&
[
k
,
v
]
:
spec
2index
)
{
for
(
auto
&
[
k
,
v
]
:
spec
_infos
)
{
fmt
::
print
(
"{}: {}
, "
,
k
,
v
);
fmt
::
print
(
"{}: {}
{}, "
,
k
,
v
.
index
,
v
.
cid
);
}
}
fmt
::
print
(
"
\n
"
);
fmt
::
print
(
"
\n
"
);
// throw std::runtime_error("Spec not found: " + spec);
// throw std::runtime_error("Spec not found: " + spec);
idx
=
1
;
spec_infos
[
spec
]
=
{
1
,
1
};
}
else
{
return
spec_infos
[
spec
];
idx
=
it
->
second
;
}
}
}
feat
(
i
,
0
)
=
static_cast
<
uint8_t
>
(
idx
>>
8
);
return
it
->
second
;
feat
(
i
,
1
)
=
static_cast
<
uint8_t
>
(
idx
&
0xff
);
}
void
_set_obs_action_msg
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
msg
)
{
feat
(
i
,
2
)
=
msg_to_id
(
msg
);
}
}
void
_set_obs_action_
act
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
act
,
void
_set_obs_action_
spec
(
uint8_t
act_offset
=
0
)
{
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
idx
)
{
feat
(
i
,
3
)
=
cmd_act_to_id
(
act
)
+
act_offset
;
feat
(
i
,
0
)
=
static_cast
<
uint8_t
>
(
idx
)
;
}
}
void
_set_obs_action_yesno
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
yesno
)
{
void
_set_obs_action_card_id
(
feat
(
i
,
4
)
=
cmd_yesno_to_id
(
yesno
);
TArray
<
uint8_t
>
&
feat
,
int
i
,
CardId
cid
)
{
feat
(
i
,
1
)
=
static_cast
<
uint8_t
>
(
cid
>>
8
);
feat
(
i
,
2
)
=
static_cast
<
uint8_t
>
(
cid
&
0xff
);
}
}
void
_set_obs_action_
phase
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
phase
)
{
void
_set_obs_action_
msg
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
msg
)
{
feat
(
i
,
5
)
=
cmd_phase_to_id
(
phase
);
feat
(
i
,
3
)
=
msg_to_id
(
msg
);
}
}
void
_set_obs_action_
cancel
(
TArray
<
uint8_t
>
&
feat
,
int
i
)
{
void
_set_obs_action_
act
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
ActionAct
act
)
{
feat
(
i
,
6
)
=
1
;
feat
(
i
,
4
)
=
static_cast
<
uint8_t
>
(
act
)
;
}
}
void
_set_obs_action_finish
(
TArray
<
uint8_t
>
&
feat
,
int
i
)
{
void
_set_obs_action_finish
(
TArray
<
uint8_t
>
&
feat
,
int
i
)
{
feat
(
i
,
7
)
=
1
;
feat
(
i
,
5
)
=
1
;
}
void
_set_obs_action_effect
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
effect
)
{
// 0: None
// 1: default
// 2-15: card effect
// 16+: system
if
(
effect
==
-
1
)
{
effect
=
0
;
}
else
if
(
effect
==
0
)
{
effect
=
1
;
}
else
if
(
effect
>=
CARD_EFFECT_OFFSET
)
{
effect
=
effect
-
CARD_EFFECT_OFFSET
+
2
;
}
else
{
effect
=
system_string_to_id
(
effect
);
}
feat
(
i
,
6
)
=
static_cast
<
uint8_t
>
(
effect
);
}
}
void
_set_obs_action_position
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
position
)
{
void
_set_obs_action_phase
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
ActionPhase
phase
){
position
=
1
<<
(
position
-
'1'
);
feat
(
i
,
7
)
=
static_cast
<
uint8_t
>
(
phase
);
feat
(
i
,
8
)
=
position_to_id
(
position
);
}
}
void
_set_obs_action_
option
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
op
tion
)
{
void
_set_obs_action_
position
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
posi
tion
)
{
feat
(
i
,
9
)
=
option
-
'0'
;
feat
(
i
,
8
)
=
position_to_id
(
position
)
;
}
}
void
_set_obs_action_number
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
number
)
{
void
_set_obs_action_number
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
number
)
{
feat
(
i
,
10
)
=
number
-
'0'
;
feat
(
i
,
9
)
=
number
;
}
}
void
_set_obs_action_place
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
const
std
::
string
&
spec
)
{
void
_set_obs_action_place
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
ActionPlace
place
)
{
feat
(
i
,
1
1
)
=
cmd_place_to_id
(
spec
);
feat
(
i
,
1
0
)
=
static_cast
<
uint8_t
>
(
place
);
}
}
void
_set_obs_action_attrib
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
attrib
)
{
void
_set_obs_action_attrib
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
attrib
)
{
feat
(
i
,
1
2
)
=
attribute_to_id
(
attrib
);
feat
(
i
,
1
1
)
=
attribute_to_id
(
attrib
);
}
}
void
_set_obs_action
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
msg
,
void
_set_obs_action
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
const
LegalAction
&
action
)
{
const
std
::
string
&
option
,
const
SpecIndex
&
spec2index
,
auto
msg
=
action
.
msg_
;
CardId
card_id
)
{
_set_obs_action_msg
(
feat
,
i
,
msg
);
_set_obs_action_msg
(
feat
,
i
,
msg
);
if
(
msg
==
MSG_SELECT_IDLECMD
)
{
_set_obs_action_card_id
(
feat
,
i
,
action
.
cid_
);
if
(
option
==
"b"
||
option
==
"e"
)
{
if
(
msg
==
MSG_SELECT_CARD
||
msg
==
MSG_SELECT_TRIBUTE
||
_set_obs_action_phase
(
feat
,
i
,
option
[
0
]);
msg
==
MSG_SELECT_SUM
||
msg
==
MSG_SELECT_UNSELECT_CARD
)
{
}
else
{
if
(
action
.
finish_
)
{
auto
act
=
option
[
0
];
auto
spec
=
option
.
substr
(
2
);
uint8_t
offset
=
0
;
int
n
=
spec
.
size
();
if
(
act
==
'v'
&&
std
::
isalpha
(
spec
[
n
-
1
]))
{
offset
=
spec
[
n
-
1
]
-
'a'
;
spec
=
spec
.
substr
(
0
,
n
-
1
);
}
_set_obs_action_act
(
feat
,
i
,
act
,
offset
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
}
}
else
if
(
msg
==
MSG_SELECT_CHAIN
)
{
if
(
option
[
0
]
==
'c'
)
{
_set_obs_action_cancel
(
feat
,
i
);
}
else
{
char
act
=
'v'
;
auto
spec
=
option
;
uint8_t
offset
=
0
;
auto
n
=
spec
.
size
();
if
(
std
::
isalpha
(
spec
[
n
-
1
]))
{
offset
=
spec
[
n
-
1
]
-
'a'
;
spec
=
spec
.
substr
(
0
,
n
-
1
);
}
_set_obs_action_act
(
feat
,
i
,
act
,
offset
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
}
}
else
if
(
msg
==
MSG_SELECT_CARD
||
msg
==
MSG_SELECT_TRIBUTE
||
msg
==
MSG_SELECT_SUM
||
msg
==
MSG_SELECT_UNSELECT_CARD
)
{
if
(
option
[
0
]
==
'f'
)
{
_set_obs_action_finish
(
feat
,
i
);
_set_obs_action_finish
(
feat
,
i
);
}
else
{
}
else
{
_set_obs_action_spec
(
feat
,
i
,
option
,
spec2index
,
card_id
);
_set_obs_action_spec
(
feat
,
i
,
action
.
spec_index_
);
}
}
}
else
if
(
msg
==
MSG_SELECT_POSITION
)
{
}
else
if
(
msg
==
MSG_SELECT_POSITION
)
{
_set_obs_action_position
(
feat
,
i
,
option
[
0
]
);
_set_obs_action_position
(
feat
,
i
,
action
.
position_
);
}
else
if
(
msg
==
MSG_SELECT_EFFECTYN
)
{
}
else
if
(
msg
==
MSG_SELECT_EFFECTYN
)
{
auto
spec
=
option
.
substr
(
2
);
_set_obs_action_spec
(
feat
,
i
,
action
.
spec_index_
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
_set_obs_action_act
(
feat
,
i
,
action
.
act_
);
_set_obs_action_effect
(
feat
,
i
,
action
.
effect_
);
_set_obs_action_yesno
(
feat
,
i
,
option
[
0
]);
}
else
if
(
msg
==
MSG_SELECT_YESNO
||
msg
==
MSG_SELECT_OPTION
)
{
}
else
if
(
msg
==
MSG_SELECT_YESNO
)
{
_set_obs_action_act
(
feat
,
i
,
action
.
act_
);
_set_obs_action_yesno
(
feat
,
i
,
option
[
0
]);
_set_obs_action_effect
(
feat
,
i
,
action
.
effect_
);
}
else
if
(
msg
==
MSG_SELECT_BATTLECMD
)
{
}
else
if
(
if
(
option
==
"m"
||
option
==
"e"
)
{
msg
==
MSG_SELECT_BATTLECMD
||
_set_obs_action_phase
(
feat
,
i
,
option
[
0
]);
msg
==
MSG_SELECT_IDLECMD
||
}
else
{
msg
==
MSG_SELECT_CHAIN
)
{
auto
act
=
option
[
0
];
_set_obs_action_phase
(
feat
,
i
,
action
.
phase_
);
auto
spec
=
option
.
substr
(
2
);
_set_obs_action_spec
(
feat
,
i
,
action
.
spec_index_
);
_set_obs_action_act
(
feat
,
i
,
act
);
_set_obs_action_act
(
feat
,
i
,
action
.
act_
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
_set_obs_action_effect
(
feat
,
i
,
action
.
effect_
);
}
}
else
if
(
msg
==
MSG_SELECT_OPTION
)
{
_set_obs_action_option
(
feat
,
i
,
option
[
0
]);
}
else
if
(
msg
==
MSG_SELECT_PLACE
||
msg_
==
MSG_SELECT_DISFIELD
)
{
}
else
if
(
msg
==
MSG_SELECT_PLACE
||
msg_
==
MSG_SELECT_DISFIELD
)
{
_set_obs_action_place
(
feat
,
i
,
option
);
_set_obs_action_place
(
feat
,
i
,
action
.
place_
);
}
else
if
(
msg
==
MSG_ANNOUNCE_ATTRIB
)
{
}
else
if
(
msg
==
MSG_ANNOUNCE_ATTRIB
)
{
_set_obs_action_attrib
(
feat
,
i
,
1
<<
(
option
[
0
]
-
'1'
)
);
_set_obs_action_attrib
(
feat
,
i
,
action
.
attribute_
);
}
else
if
(
msg
==
MSG_ANNOUNCE_NUMBER
)
{
}
else
if
(
msg
==
MSG_ANNOUNCE_NUMBER
)
{
_set_obs_action_number
(
feat
,
i
,
option
[
0
]
);
_set_obs_action_number
(
feat
,
i
,
action
.
number_
);
}
else
{
}
else
{
throw
std
::
runtime_error
(
"Unsupported message "
+
std
::
to_string
(
msg
));
throw
std
::
runtime_error
(
"Unsupported message "
+
std
::
to_string
(
msg
));
}
}
...
@@ -2302,49 +2523,42 @@ private:
...
@@ -2302,49 +2523,42 @@ private:
CardId
spec_to_card_id
(
const
std
::
string
&
spec
,
PlayerId
player
)
{
CardId
spec_to_card_id
(
const
std
::
string
&
spec
,
PlayerId
player
)
{
int
offset
=
0
;
int
offset
=
0
;
// TODO: possible info leak
bool
opponent
=
false
;
if
(
spec
[
0
]
==
'o'
)
{
if
(
spec
[
0
]
==
'o'
)
{
player
=
1
-
player
;
player
=
1
-
player
;
opponent
=
true
;
offset
++
;
offset
++
;
}
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
.
substr
(
offset
));
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
.
substr
(
offset
));
return
c_get_card_id
(
get_card_code
(
player
,
loc
,
seq
));
if
(
opponent
)
{
}
bool
hidden_for_opponent
=
true
;
if
(
CardId
parse_card_id
(
const
std
::
string
&
option
,
PlayerId
player
)
{
loc
==
LOCATION_MZONE
||
loc
==
LOCATION_SZONE
||
CardId
card_id
=
0
;
loc
==
LOCATION_GRAVE
||
loc
==
LOCATION_REMOVED
)
{
if
(
msg_
==
MSG_SELECT_IDLECMD
)
{
hidden_for_opponent
=
false
;
if
(
!
(
option
==
"b"
||
option
==
"e"
))
{
auto
n
=
option
.
size
();
if
(
std
::
isalpha
(
option
[
n
-
1
]))
{
card_id
=
spec_to_card_id
(
option
.
substr
(
2
,
n
-
3
),
player
);
}
else
{
card_id
=
spec_to_card_id
(
option
.
substr
(
2
),
player
);
}
}
}
}
else
if
(
msg_
==
MSG_SELECT_CHAIN
)
{
if
(
revealed_
.
size
()
!=
0
)
{
if
(
option
!=
"c"
)
{
hidden_for_opponent
=
false
;
card_id
=
spec_to_card_id
(
option
,
player
);
}
}
}
else
if
(
msg_
==
MSG_SELECT_CARD
||
msg_
==
MSG_SELECT_TRIBUTE
||
if
(
hidden_for_opponent
)
{
msg_
==
MSG_SELECT_SUM
||
msg_
==
MSG_SELECT_UNSELECT_CARD
)
{
return
0
;
if
(
option
[
0
]
!=
'f'
)
{
card_id
=
spec_to_card_id
(
option
,
player
);
}
}
}
else
if
(
msg_
==
MSG_SELECT_EFFECTYN
)
{
Card
c
=
get_card
(
player
,
loc
,
seq
);
card_id
=
spec_to_card_id
(
option
.
substr
(
2
),
player
);
bool
hide
=
c
.
position_
&
POS_FACEDOWN
;
}
else
if
(
msg_
==
MSG_SELECT_BATTLECMD
)
{
if
(
revealed_
.
find
(
spec
)
!=
revealed_
.
end
())
{
if
(
!
(
option
==
"m"
||
option
==
"e"
))
{
hide
=
false
;
card_id
=
spec_to_card_id
(
option
.
substr
(
2
),
player
);
}
CardId
card_id
=
0
;
if
(
!
hide
)
{
card_id
=
c_get_card_id
(
c
.
code_
);
}
}
}
}
return
c
ard_id
;
return
c
_get_card_id
(
get_card_code
(
player
,
loc
,
seq
))
;
}
}
void
_set_obs_actions
(
TArray
<
uint8_t
>
&
feat
,
const
SpecIndex
&
spec2index
,
void
_set_obs_actions
(
TArray
<
uint8_t
>
&
feat
,
const
std
::
vector
<
LegalAction
>
&
actions
)
{
int
msg
,
const
std
::
vector
<
std
::
string
>
&
options
)
{
for
(
int
i
=
0
;
i
<
actions
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
options
.
size
();
++
i
)
{
_set_obs_action
(
feat
,
i
,
actions
[
i
]);
_set_obs_action
(
feat
,
i
,
msg
,
options
[
i
],
spec2index
,
0
);
}
}
}
}
...
@@ -2451,7 +2665,7 @@ private:
...
@@ -2451,7 +2665,7 @@ private:
void
WriteState
(
float
reward
,
int
win_reason
=
0
)
{
void
WriteState
(
float
reward
,
int
win_reason
=
0
)
{
State
state
=
Allocate
();
State
state
=
Allocate
();
int
n_options
=
op
tions_
.
size
();
int
n_options
=
legal_ac
tions_
.
size
();
state
[
"reward"
_
]
=
reward
;
state
[
"reward"
_
]
=
reward
;
state
[
"info:to_play"
_
]
=
int
(
to_play_
);
state
[
"info:to_play"
_
]
=
int
(
to_play_
);
state
[
"info:is_selfplay"
_
]
=
int
(
play_mode_
==
kSelfPlay
);
state
[
"info:is_selfplay"
_
]
=
int
(
play_mode_
==
kSelfPlay
);
...
@@ -2463,62 +2677,69 @@ private:
...
@@ -2463,62 +2677,69 @@ private:
return
;
return
;
}
}
auto
[
spec
2index
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
auto
[
spec
_infos
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
// we can't shuffle because idx must be stable in callback
// we can't shuffle because idx must be stable in callback
if
(
n_options
>
max_options
())
{
if
(
n_options
>
max_options
())
{
op
tions_
.
resize
(
max_options
());
legal_ac
tions_
.
resize
(
max_options
());
}
}
// print spec2index
n_options
=
legal_actions_
.
size
();
// for (auto const& [key, val] : spec2index) {
// fmt::println("{} {}", key, val);
// }
_set_obs_actions
(
state
[
"obs:actions_"
_
],
spec2index
,
msg_
,
options_
);
n_options
=
options_
.
size
();
state
[
"info:num_options"
_
]
=
n_options
;
state
[
"info:num_options"
_
]
=
n_options
;
// update_h_card_ids from state
for
(
int
i
=
0
;
i
<
n_options
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n_options
;
++
i
)
{
uint8_t
spec_index1
=
state
[
"obs:actions_"
_
](
i
,
0
)
;
auto
&
action
=
legal_actions_
[
i
]
;
uint8_t
spec_index2
=
state
[
"obs:actions_"
_
](
i
,
1
)
;
action
.
msg_
=
msg_
;
uint16_t
spec_index
=
(
static_cast
<
uint16_t
>
(
spec_index1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
spec_index2
)
;
const
auto
&
spec
=
action
.
spec_
;
if
(
spec_index
==
0
)
{
if
(
!
spec
.
empty
()
)
{
h_card_ids_
[
i
]
=
0
;
const
auto
&
spec_info
=
find_spec_info
(
spec_infos
,
spec
)
;
}
else
{
action
.
spec_index_
=
spec_info
.
index
;
uint8_t
card_id1
=
state
[
"obs:cards_"
_
](
spec_index
-
1
,
0
);
if
(
action
.
cid_
==
0
)
{
uint8_t
card_id2
=
state
[
"obs:cards_"
_
](
spec_index
-
1
,
1
)
;
action
.
cid_
=
spec_info
.
cid
;
h_card_ids_
[
i
]
=
(
static_cast
<
uint16_t
>
(
card_id1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
card_id2
);
}
}
}
}
}
_set_obs_actions
(
state
[
"obs:actions_"
_
],
legal_actions_
);
// write history actions
// write history actions
int
offset
=
n_history_actions_
-
ha_p_
;
auto
ha_p
=
to_play_
==
0
?
ha_p_1_
:
ha_p_2_
;
int
n_h_action_feats
=
history_actions_
.
Shape
()[
1
];
auto
&
history_actions
=
to_play_
==
0
?
history_actions_1_
:
history_actions_2_
;
int
offset
=
n_history_actions_
-
ha_p
;
int
n_h_action_feats
=
history_actions
.
Shape
()[
1
];
state
[
"obs:h_actions_"
_
].
Assign
(
state
[
"obs:h_actions_"
_
].
Assign
(
(
uint8_t
*
)
history_actions
_
[
ha_p_
].
Data
(),
n_h_action_feats
*
offset
);
(
uint8_t
*
)
history_actions
[
ha_p
].
Data
(),
n_h_action_feats
*
offset
);
state
[
"obs:h_actions_"
_
][
offset
].
Assign
(
state
[
"obs:h_actions_"
_
][
offset
].
Assign
(
(
uint8_t
*
)
history_actions
_
.
Data
(),
n_h_action_feats
*
ha_p_
);
(
uint8_t
*
)
history_actions
.
Data
(),
n_h_action_feats
*
ha_p
);
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
if
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
2
))
==
0
)
{
if
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
3
))
==
0
)
{
break
;
break
;
}
}
state
[
"obs:h_actions_"
_
](
i
,
13
)
=
static_cast
<
uint8_t
>
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
13
))
==
to_play_
);
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12
)) == to_play_);
int
turn_diff
=
std
::
min
(
16
,
turn_count_
-
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
1
4
)));
int
turn_diff
=
std
::
min
(
16
,
turn_count_
-
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
1
2
)));
state
[
"obs:h_actions_"
_
](
i
,
1
4
)
=
static_cast
<
uint8_t
>
(
turn_diff
);
state
[
"obs:h_actions_"
_
](
i
,
1
2
)
=
static_cast
<
uint8_t
>
(
turn_diff
);
}
}
}
}
void
show_decision
(
int
idx
)
{
void
show_decision
(
int
idx
)
{
fmt
::
println
(
"Player {} chose
\"
{}
\"
in {}"
,
to_play_
,
options_
[
idx
],
std
::
string
s
;
options_
);
const
auto
&
a
=
legal_actions_
[
idx
];
if
(
!
a
.
spec_
.
empty
())
{
s
=
a
.
spec_
;
}
else
if
(
a
.
place_
!=
ActionPlace
::
None
)
{
s
=
action_place_to_string
(
a
.
place_
);
}
else
if
(
a
.
position_
!=
0
)
{
s
=
position_to_string
(
a
.
position_
);
}
else
{
s
=
fmt
::
format
(
"{}"
,
a
);
}
fmt
::
print
(
"Player {} chose
\"
{}
\"
in {}
\n
"
,
to_play_
,
s
,
legal_actions_
);
}
}
std
::
tuple
<
std
::
vector
<
CardCode
>
,
std
::
vector
<
CardCode
>
,
std
::
string
>
std
::
tuple
<
std
::
vector
<
CardCode
>
,
std
::
vector
<
CardCode
>
,
std
::
string
>
...
@@ -2581,15 +2802,19 @@ private:
...
@@ -2581,15 +2802,19 @@ private:
handle_multi_select
();
handle_multi_select
();
}
else
{
}
else
{
handle_message
();
handle_message
();
if
(
op
tions_
.
empty
())
{
if
(
legal_ac
tions_
.
empty
())
{
continue
;
continue
;
}
}
}
}
if
((
play_mode_
==
kSelfPlay
)
||
(
to_play_
==
ai_player_
))
{
if
((
play_mode_
==
kSelfPlay
)
||
(
to_play_
==
ai_player_
))
{
if
(
op
tions_
.
size
()
==
1
)
{
if
(
legal_ac
tions_
.
size
()
==
1
)
{
callback_
(
0
);
callback_
(
0
);
update_h_card_ids
(
to_play_
,
0
);
auto
la
=
legal_actions_
[
0
];
update_history_actions
(
to_play_
,
0
);
la
.
msg_
=
msg_
;
if
(
la
.
cid_
==
0
&&
!
la
.
spec_
.
empty
())
{
la
.
cid_
=
spec_to_card_id
(
la
.
spec_
,
to_play_
);
}
update_history_actions
(
to_play_
,
la
);
if
(
verbose_
)
{
if
(
verbose_
)
{
show_decision
(
0
);
show_decision
(
0
);
}
}
...
@@ -2597,7 +2822,7 @@ private:
...
@@ -2597,7 +2822,7 @@ private:
return
;
return
;
}
}
}
else
{
}
else
{
auto
idx
=
players_
[
to_play_
]
->
think
(
op
tions_
);
auto
idx
=
players_
[
to_play_
]
->
think
(
legal_ac
tions_
);
callback_
(
idx
);
callback_
(
idx
);
if
(
verbose_
)
{
if
(
verbose_
)
{
show_decision
(
idx
);
show_decision
(
idx
);
...
@@ -2606,7 +2831,7 @@ private:
...
@@ -2606,7 +2831,7 @@ private:
}
}
}
}
done_
=
true
;
done_
=
true
;
op
tions_
.
clear
();
legal_ac
tions_
.
clear
();
}
}
uint8_t
read_u8
()
{
return
data_
[
dp_
++
];
}
uint8_t
read_u8
()
{
return
data_
[
dp_
++
];
}
...
@@ -2653,7 +2878,12 @@ private:
...
@@ -2653,7 +2878,12 @@ private:
int32_t
bl
=
YGO_QueryCard
(
pduel_
,
player
,
loc
,
seq
,
flags
,
query_buf_
);
int32_t
bl
=
YGO_QueryCard
(
pduel_
,
player
,
loc
,
seq
,
flags
,
query_buf_
);
qdp_
=
0
;
qdp_
=
0
;
if
(
bl
<=
0
)
{
if
(
bl
<=
0
)
{
throw
std
::
runtime_error
(
"[get_card] Invalid card (bl <= 0)"
);
show_deck
(
0
);
show_deck
(
1
);
show_turn
();
show_buffer
();
auto
s
=
fmt
::
format
(
"[get_card] Invalid card (bl <= 0), player: {}, loc: {}, seq: {}"
,
player
,
loc
,
seq
);
throw
std
::
runtime_error
(
s
);
}
}
uint32_t
f
=
q_read_u32
();
uint32_t
f
=
q_read_u32
();
if
(
f
==
LEN_EMPTY
)
{
if
(
f
==
LEN_EMPTY
)
{
...
@@ -2728,7 +2958,7 @@ private:
...
@@ -2728,7 +2958,7 @@ private:
c
.
attack_
=
q_read_u32
();
c
.
attack_
=
q_read_u32
();
c
.
defense_
=
q_read_u32
();
c
.
defense_
=
q_read_u32
();
// TODO: equip_target
// TODO
(2)
: equip_target
if
(
f
&
QUERY_EQUIP_CARD
)
{
if
(
f
&
QUERY_EQUIP_CARD
)
{
q_read_u32
();
q_read_u32
();
}
}
...
@@ -2744,7 +2974,7 @@ private:
...
@@ -2744,7 +2974,7 @@ private:
cards
.
push_back
(
c_
);
cards
.
push_back
(
c_
);
}
}
// TODO: counters
// TODO
(2)
: counters
uint32_t
n_counters
=
q_read_u32
();
uint32_t
n_counters
=
q_read_u32
();
for
(
int
i
=
0
;
i
<
n_counters
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n_counters
;
++
i
)
{
if
(
i
==
0
)
{
if
(
i
==
0
)
{
...
@@ -2803,7 +3033,7 @@ private:
...
@@ -2803,7 +3033,7 @@ private:
auto
controller
=
read_u8
();
auto
controller
=
read_u8
();
auto
loc
=
read_u8
();
auto
loc
=
read_u8
();
auto
seq
=
read_u8
();
auto
seq
=
read_u8
();
uint32_t
data
=
-
1
;
uint32_t
data
=
0
;
if
(
extra
)
{
if
(
extra
)
{
if
(
extra8
)
{
if
(
extra8
)
{
data
=
read_u8
();
data
=
read_u8
();
...
@@ -2816,6 +3046,23 @@ private:
...
@@ -2816,6 +3046,23 @@ private:
return
card_specs
;
return
card_specs
;
}
}
std
::
tuple
<
CardCode
,
int
>
unpack_desc
(
CardCode
code
,
uint32_t
desc
)
{
if
(
desc
<
DESCRIPTION_LIMIT
)
{
return
{
0
,
desc
};
}
CardCode
code_
=
desc
>>
4
;
int
idx
=
desc
&
0xf
;
if
(
idx
<
0
||
idx
>=
14
)
{
fmt
::
print
(
"Code: {}, Code_: {}, Desc: {}
\n
"
,
code
,
code_
,
desc
);
show_deck
(
0
);
show_deck
(
1
);
show_buffer
();
show_turn
();
throw
std
::
runtime_error
(
"Invalid effect index: "
+
std
::
to_string
(
idx
));
}
return
{
code_
,
idx
+
CARD_EFFECT_OFFSET
};
}
std
::
string
cardlist_info_for_player
(
const
Card
&
card
,
PlayerId
pl
)
{
std
::
string
cardlist_info_for_player
(
const
Card
&
card
,
PlayerId
pl
)
{
std
::
string
spec
=
card
.
get_spec
(
pl
);
std
::
string
spec
=
card
.
get_spec
(
pl
);
if
(
card
.
location_
==
LOCATION_DECK
)
{
if
(
card
.
location_
==
LOCATION_DECK
)
{
...
@@ -2833,7 +3080,7 @@ private:
...
@@ -2833,7 +3080,7 @@ private:
// 3. update to_play_ and options_ if need action
// 3. update to_play_ and options_ if need action
void
handle_message
()
{
void
handle_message
()
{
msg_
=
int
(
data_
[
dp_
++
]);
msg_
=
int
(
data_
[
dp_
++
]);
op
tions_
=
{};
legal_ac
tions_
=
{};
if
(
verbose_
)
{
if
(
verbose_
)
{
fmt
::
println
(
"Message {}, length {}, dp {}"
,
msg_to_string
(
msg_
),
dl_
,
dp_
);
fmt
::
println
(
"Message {}, length {}, dp {}"
,
msg_to_string
(
msg_
),
dl_
,
dp_
);
...
@@ -3097,11 +3344,11 @@ private:
...
@@ -3097,11 +3344,11 @@ private:
uint8_t
pos
=
read_u8
();
uint8_t
pos
=
read_u8
();
uint8_t
type
=
read_u8
();
uint8_t
type
=
read_u8
();
uint32_t
value
=
read_u32
();
uint32_t
value
=
read_u32
();
Card
card
=
get_card
(
player
,
loc
,
seq
);
if
(
card
.
code_
==
0
)
{
return
;
}
if
(
type
==
CHINT_RACE
)
{
if
(
type
==
CHINT_RACE
)
{
Card
card
=
get_card
(
player
,
loc
,
seq
);
if
(
card
.
code_
==
0
)
{
return
;
}
std
::
string
races_str
=
"TODO"
;
std
::
string
races_str
=
"TODO"
;
for
(
PlayerId
pl
=
0
;
pl
<
2
;
pl
++
)
{
for
(
PlayerId
pl
=
0
;
pl
<
2
;
pl
++
)
{
players_
[
pl
]
->
notify
(
fmt
::
format
(
"{} ({}) selected {}."
,
players_
[
pl
]
->
notify
(
fmt
::
format
(
"{} ({}) selected {}."
,
...
@@ -3109,6 +3356,10 @@ private:
...
@@ -3109,6 +3356,10 @@ private:
races_str
));
races_str
));
}
}
}
else
if
(
type
==
CHINT_ATTRIBUTE
)
{
}
else
if
(
type
==
CHINT_ATTRIBUTE
)
{
Card
card
=
get_card
(
player
,
loc
,
seq
);
if
(
card
.
code_
==
0
)
{
return
;
}
std
::
string
attributes_str
=
"TODO"
;
std
::
string
attributes_str
=
"TODO"
;
for
(
PlayerId
pl
=
0
;
pl
<
2
;
pl
++
)
{
for
(
PlayerId
pl
=
0
;
pl
<
2
;
pl
++
)
{
players_
[
pl
]
->
notify
(
fmt
::
format
(
"{} ({}) selected {}."
,
players_
[
pl
]
->
notify
(
fmt
::
format
(
"{} ({}) selected {}."
,
...
@@ -3229,7 +3480,7 @@ private:
...
@@ -3229,7 +3480,7 @@ private:
return
;
return
;
}
}
dp_
+=
6
;
dp_
+=
6
;
// TODO: implement output
// TODO
(3)
: implement output
}
else
if
(
msg_
==
MSG_CARD_TARGET
)
{
}
else
if
(
msg_
==
MSG_CARD_TARGET
)
{
if
(
!
verbose_
)
{
if
(
!
verbose_
)
{
dp_
=
dl_
;
dp_
=
dl_
;
...
@@ -3301,7 +3552,7 @@ private:
...
@@ -3301,7 +3552,7 @@ private:
players_
[
pl
]
->
notify
(
str
);
players_
[
pl
]
->
notify
(
str
);
}
}
}
else
if
(
msg_
==
MSG_SORT_CARD
)
{
}
else
if
(
msg_
==
MSG_SORT_CARD
)
{
// TODO: implement action
// TODO
(3)
: implement action
if
(
!
verbose_
)
{
if
(
!
verbose_
)
{
dp_
=
dl_
;
dp_
=
dl_
;
resp_buf_
[
0
]
=
255
;
resp_buf_
[
0
]
=
255
;
...
@@ -3374,7 +3625,7 @@ private:
...
@@ -3374,7 +3625,7 @@ private:
auto
pl
=
players_
[
player
];
auto
pl
=
players_
[
player
];
PlayerId
op_id
=
1
-
player
;
PlayerId
op_id
=
1
-
player
;
auto
op
=
players_
[
op_id
];
auto
op
=
players_
[
op_id
];
// TODO: counter type to string
// TODO
(3)
: counter type to string
pl
->
notify
(
fmt
::
format
(
"{} counter(s) of type {} placed on {} ()."
,
count
,
"UNK"
,
c
.
name_
,
c
.
get_spec
(
player
)));
pl
->
notify
(
fmt
::
format
(
"{} counter(s) of type {} placed on {} ()."
,
count
,
"UNK"
,
c
.
name_
,
c
.
get_spec
(
player
)));
op
->
notify
(
fmt
::
format
(
"{} counter(s) of type {} placed on {} ()."
,
count
,
"UNK"
,
c
.
name_
,
c
.
get_spec
(
op_id
)));
op
->
notify
(
fmt
::
format
(
"{} counter(s) of type {} placed on {} ()."
,
count
,
"UNK"
,
c
.
name_
,
c
.
get_spec
(
op_id
)));
}
else
if
(
msg_
==
MSG_REMOVE_COUNTER
)
{
}
else
if
(
msg_
==
MSG_REMOVE_COUNTER
)
{
...
@@ -3406,7 +3657,7 @@ private:
...
@@ -3406,7 +3657,7 @@ private:
dp_
=
dl_
;
dp_
=
dl_
;
return
;
return
;
}
}
// TODO: implement output
// TODO
(3)
: implement output
dp_
=
dl_
;
dp_
=
dl_
;
}
else
if
(
msg_
==
MSG_SHUFFLE_DECK
)
{
}
else
if
(
msg_
==
MSG_SHUFFLE_DECK
)
{
if
(
!
verbose_
)
{
if
(
!
verbose_
)
{
...
@@ -3699,52 +3950,64 @@ private:
...
@@ -3699,52 +3950,64 @@ private:
if
(
verbose_
)
{
if
(
verbose_
)
{
pl
->
notify
(
"Battle menu:"
);
pl
->
notify
(
"Battle menu:"
);
}
}
for
(
const
auto
[
code
,
spec
,
data
]
:
activatable
)
{
for
(
const
auto
[
code_t
,
spec
,
desc
]
:
activatable
)
{
// TODO: Add effect description to indicate which effect is being activated
CardCode
code
=
code_t
;
options_
.
push_back
(
"v "
+
spec
);
if
(
code
&
0x80000000
)
{
code
&=
0x7fffffff
;
}
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
auto
c
=
c_get_card
(
code
);
auto
c
=
get_card
(
player
,
loc
,
seq
);
int
cmd_idx
=
legal_actions_
.
size
(
);
pl
->
notify
(
"v "
+
spec
+
": activate "
+
c
.
name_
+
" ("
+
std
::
string
s
=
fmt
::
format
(
std
::
to_string
(
c
.
attack_
)
+
"/"
+
"{}: activate {}({}) [{}/{}] ({})"
,
std
::
to_string
(
c
.
defense_
)
+
")"
);
cmd_idx
,
c
.
name_
,
spec
,
c
.
attack_
,
c
.
defense_
,
c
.
get_effect_description
(
code_d
,
eff_idx
)
);
}
}
}
}
for
(
const
auto
[
code
,
spec
,
data
]
:
attackable
)
{
for
(
const
auto
[
code
,
spec
,
data
]
:
attackable
)
{
// TODO: add this as feature
bool
direct_attackable
=
data
&
0x1
;
bool
direct_attackable
=
data
&
0x1
;
options_
.
push_back
(
"a "
+
spec
);
auto
act
=
direct_attackable
?
ActionAct
::
DirectAttack
:
ActionAct
::
Attack
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
act
,
spec
));
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
auto
[
controller
,
loc
,
seq
,
pos
]
=
spec_to_ls
(
player
,
spec
);
auto
c
=
get_card
(
player
,
loc
,
seq
);
auto
c
=
get_card
(
controller
,
loc
,
seq
);
std
::
string
s
;
int
cmd_idx
=
legal_actions_
.
size
();
auto
attack_str
=
direct_attackable
?
"direct attack"
:
"attack"
;
std
::
string
s
=
fmt
::
format
(
"{}: {} {}({}) "
,
cmd_idx
,
attack_str
,
c
.
name_
,
spec
);
if
(
c
.
type_
&
TYPE_LINK
)
{
if
(
c
.
type_
&
TYPE_LINK
)
{
s
=
"a "
+
spec
+
": "
+
c
.
name_
+
" ("
+
s
+=
fmt
::
format
(
"[{}]"
,
c
.
attack_
);
std
::
to_string
(
c
.
attack_
)
+
")"
;
}
else
{
}
else
{
s
=
"a "
+
spec
+
": "
+
c
.
name_
+
" ("
+
s
+=
fmt
::
format
(
"[{}/{}]"
,
c
.
attack_
,
c
.
defense_
);
std
::
to_string
(
c
.
attack_
)
+
"/"
+
std
::
to_string
(
c
.
defense_
)
+
")"
;
}
if
(
direct_attackable
)
{
s
+=
" direct attack"
;
}
else
{
s
+=
" attack"
;
}
}
pl
->
notify
(
s
);
pl
->
notify
(
s
);
}
}
}
}
if
(
to_m2
)
{
if
(
to_m2
)
{
options_
.
push_back
(
"m"
);
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
Main2
));
int
cmd_idx
=
legal_actions_
.
size
();
if
(
verbose_
)
{
if
(
verbose_
)
{
pl
->
notify
(
"m: Main phase 2."
);
pl
->
notify
(
fmt
::
format
(
"{}: Main phase 2."
,
cmd_idx
)
);
}
}
}
}
if
(
to_ep
)
{
if
(
to_ep
)
{
if
(
!
to_m2
)
{
if
(
!
to_m2
)
{
options_
.
push_back
(
"e"
);
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
End
));
int
cmd_idx
=
legal_actions_
.
size
();
if
(
verbose_
)
{
if
(
verbose_
)
{
pl
->
notify
(
"e: End phase."
);
pl
->
notify
(
fmt
::
format
(
"{}: End phase."
,
cmd_idx
)
);
}
}
}
}
}
}
...
@@ -3752,14 +4015,15 @@ private:
...
@@ -3752,14 +4015,15 @@ private:
int
n_attackables
=
attackable
.
size
();
int
n_attackables
=
attackable
.
size
();
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
,
n_activatables
,
n_attackables
,
to_ep
,
to_m2
](
int
idx
)
{
callback_
=
[
this
,
n_activatables
,
n_attackables
,
to_ep
,
to_m2
](
int
idx
)
{
const
auto
&
la
=
legal_actions_
[
idx
];
if
(
idx
<
n_activatables
)
{
if
(
idx
<
n_activatables
)
{
YGO_SetResponsei
(
pduel_
,
idx
<<
16
);
YGO_SetResponsei
(
pduel_
,
idx
<<
16
);
}
else
if
(
idx
<
(
n_activatables
+
n_attackables
))
{
}
else
if
(
idx
<
(
n_activatables
+
n_attackables
))
{
idx
=
idx
-
n_activatables
;
idx
=
idx
-
n_activatables
;
YGO_SetResponsei
(
pduel_
,
(
idx
<<
16
)
+
1
);
YGO_SetResponsei
(
pduel_
,
(
idx
<<
16
)
+
1
);
}
else
if
((
options_
[
idx
]
==
"e"
)
&&
to_ep
)
{
}
else
if
((
la
.
phase_
==
ActionPhase
::
End
)
&&
to_ep
)
{
YGO_SetResponsei
(
pduel_
,
3
);
YGO_SetResponsei
(
pduel_
,
3
);
}
else
if
((
options_
[
idx
]
==
"m"
)
&&
to_m2
)
{
}
else
if
((
la
.
phase_
==
ActionPhase
::
Main2
)
&&
to_m2
)
{
YGO_SetResponsei
(
pduel_
,
2
);
YGO_SetResponsei
(
pduel_
,
2
);
}
else
{
}
else
{
throw
std
::
runtime_error
(
"Invalid option"
);
throw
std
::
runtime_error
(
"Invalid option"
);
...
@@ -3777,21 +4041,18 @@ private:
...
@@ -3777,21 +4041,18 @@ private:
std
::
vector
<
std
::
string
>
select_specs
;
std
::
vector
<
std
::
string
>
select_specs
;
select_specs
.
reserve
(
select_size
);
select_specs
.
reserve
(
select_size
);
if
(
verbose_
)
{
if
(
verbose_
)
{
std
::
vector
<
Card
>
cards
;
auto
pl
=
players_
[
player
];
pl
->
notify
(
"Select "
+
std
::
to_string
(
min
)
+
" to "
+
std
::
to_string
(
max
)
+
" cards:"
);
for
(
int
i
=
0
;
i
<
select_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
select_size
;
++
i
)
{
auto
code
=
read_u32
();
auto
code
=
read_u32
();
auto
loc
=
read_u32
();
auto
loc
=
read_u32
();
Card
card
=
c_get_card
(
code
);
Card
card
=
c_get_card
(
code
);
card
.
set_location
(
loc
);
card
.
set_location
(
loc
);
cards
.
push_back
(
card
);
}
auto
pl
=
players_
[
player
];
pl
->
notify
(
"Select "
+
std
::
to_string
(
min
)
+
" to "
+
std
::
to_string
(
max
)
+
" cards:"
);
for
(
const
auto
&
card
:
cards
)
{
auto
spec
=
card
.
get_spec
(
player
);
auto
spec
=
card
.
get_spec
(
player
);
select_specs
.
push_back
(
spec
);
select_specs
.
push_back
(
spec
);
pl
->
notify
(
spec
+
": "
+
card
.
name_
);
auto
s
=
fmt
::
format
(
"{}: {}({})"
,
i
+
1
,
card
.
name_
,
spec
);
pl
->
notify
(
s
);
}
}
}
else
{
}
else
{
for
(
int
i
=
0
;
i
<
select_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
select_size
;
++
i
)
{
...
@@ -3807,22 +4068,22 @@ private:
...
@@ -3807,22 +4068,22 @@ private:
auto
unselect_size
=
read_u8
();
auto
unselect_size
=
read_u8
();
// unselect not allowed (no regrets
!
)
// unselect not allowed (no regrets)
dp_
+=
8
*
unselect_size
;
dp_
+=
8
*
unselect_size
;
for
(
int
j
=
0
;
j
<
select_specs
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
select_specs
.
size
();
++
j
)
{
options_
.
push_back
(
select_specs
[
j
]
);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
select_specs
[
j
])
);
}
}
if
(
finishable
)
{
if
(
finishable
)
{
options_
.
push_back
(
"f"
);
legal_actions_
.
push_back
(
LegalAction
::
finish
()
);
}
}
// cancelable and finishable not needed
// cancelable and finishable not needed
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
if
(
options_
[
idx
]
==
"f"
)
{
if
(
legal_actions_
[
idx
].
finish_
)
{
YGO_SetResponsei
(
pduel_
,
-
1
);
YGO_SetResponsei
(
pduel_
,
-
1
);
}
else
{
}
else
{
resp_buf_
[
0
]
=
1
;
resp_buf_
[
0
]
=
1
;
...
@@ -3893,7 +4154,7 @@ private:
...
@@ -3893,7 +4154,7 @@ private:
}
}
}
}
// TODO: use this when added to history actions
// TODO
(1)
: use this when added to history actions
// if ((min == max) && (max == specs.size())) {
// if ((min == max) && (max == specs.size())) {
// resp_buf_[0] = specs.size();
// resp_buf_[0] = specs.size();
// for (int i = 0; i < specs.size(); ++i) {
// for (int i = 0; i < specs.size(); ++i) {
...
@@ -3974,7 +4235,7 @@ private:
...
@@ -3974,7 +4235,7 @@ private:
// combs = combinations_with_weight(release_params, min);
// combs = combinations_with_weight(release_params, min);
}
}
// TODO: use this when added to history actions
// TODO
(1)
: use this when added to history actions
// if (max == specs.size()) {
// if (max == specs.size()) {
// // tribute all
// // tribute all
// resp_buf_[0] = specs.size();
// resp_buf_[0] = specs.size();
...
@@ -4126,25 +4387,18 @@ private:
...
@@ -4126,25 +4387,18 @@ private:
// auto hint_timing = read_u32();
// auto hint_timing = read_u32();
// auto other_timing = read_u32();
// auto other_timing = read_u32();
std
::
vector
<
Card
>
card
s
;
std
::
vector
<
Card
Code
>
code
s
;
std
::
vector
<
uint32_t
>
descs
;
std
::
vector
<
uint32_t
>
descs
;
std
::
vector
<
uint32_t
>
spec_code
s
;
std
::
vector
<
std
::
string
>
spec
s
;
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
auto
et
=
read_u8
();
auto
flag
=
read_u8
();
CardCode
code
=
read_u32
();
CardCode
code
=
read_u32
();
if
(
verbose_
)
{
codes
.
push_back
(
code
);
uint32_t
loc
=
read_u32
();
PlayerId
c
=
read_u8
();
Card
card
=
c_get_card
(
code
);
uint8_t
loc
=
read_u8
();
card
.
set_location
(
loc
);
uint8_t
seq
=
read_u8
();
cards
.
push_back
(
card
);
uint8_t
pos
=
read_u8
();
spec_codes
.
push_back
(
card
.
get_spec_code
(
player
));
specs
.
push_back
(
ls_to_spec
(
loc
,
seq
,
pos
,
c
!=
player
));
}
else
{
PlayerId
c
=
read_u8
();
uint8_t
loc
=
read_u8
();
uint8_t
seq
=
read_u8
();
uint8_t
pos
=
read_u8
();
spec_codes
.
push_back
(
ls_to_spec_code
(
loc
,
seq
,
pos
,
c
!=
player
));
}
uint32_t
desc
=
read_u32
();
uint32_t
desc
=
read_u32
();
descs
.
push_back
(
desc
);
descs
.
push_back
(
desc
);
}
}
...
@@ -4168,58 +4422,42 @@ private:
...
@@ -4168,58 +4422,42 @@ private:
op
->
seen_waiting_
=
true
;
op
->
seen_waiting_
=
true
;
}
}
std
::
vector
<
int
>
chain_index
;
if
(
verbose_
)
{
ankerl
::
unordered_dense
::
map
<
uint32_t
,
int
>
chain_counts
;
pl
->
notify
(
"Select chain:"
);
ankerl
::
unordered_dense
::
map
<
uint32_t
,
int
>
chain_orders
;
std
::
vector
<
std
::
string
>
chain_specs
;
std
::
vector
<
std
::
string
>
effect_descs
;
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
chain_index
.
push_back
(
i
);
chain_counts
[
spec_codes
[
i
]]
+=
1
;
}
}
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
auto
spec_code
=
spec_codes
[
i
];
CardCode
code
=
codes
[
i
];
auto
cs
=
code_to_spec
(
spec_code
);
uint32_t
desc
=
descs
[
i
];
auto
chain_count
=
chain_counts
[
spec_code
];
auto
spec
=
specs
[
i
];
if
(
chain_count
>
1
)
{
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
// TODO: should use desc to indicate activate which effect
if
(
desc
==
0
)
{
cs
.
push_back
(
'a'
+
chain_orders
[
spec_code
]);
code_d
=
code
;
}
chain_orders
[
spec_code
]
++
;
chain_specs
.
push_back
(
cs
);
if
(
verbose_
)
{
const
auto
&
card
=
cards
[
i
];
effect_descs
.
push_back
(
card
.
get_effect_description
(
descs
[
i
],
true
));
}
}
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
if
(
verbose_
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
if
(
forced
)
{
pl
->
notify
(
"Select chain:"
);
}
else
{
pl
->
notify
(
"Select chain (c to cancel):"
);
}
}
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
legal_actions_
.
push_back
(
la
);
const
auto
&
effect_desc
=
effect_descs
[
i
];
if
(
verbose_
)
{
if
(
effect_desc
.
empty
())
{
auto
c
=
c_get_card
(
code
);
pl
->
notify
(
chain_specs
[
i
]
+
": "
+
cards
[
i
].
name_
);
std
::
string
s
=
fmt
::
format
(
}
else
{
"{}: {}({}) ({})"
,
pl
->
notify
(
chain_specs
[
i
]
+
" ("
+
cards
[
i
].
name_
+
i
+
1
,
c
.
name_
,
spec
,
c
.
get_effect_description
(
code_d
,
eff_idx
));
"): "
+
effect_desc
);
pl
->
notify
(
s
);
}
}
}
}
}
for
(
const
auto
&
spec
:
chain_specs
)
{
options_
.
push_back
(
spec
);
}
if
(
!
forced
)
{
if
(
!
forced
)
{
options_
.
push_back
(
"c"
);
legal_actions_
.
push_back
(
LegalAction
::
cancel
());
if
(
verbose_
)
{
pl
->
notify
(
fmt
::
format
(
"{}: cancel"
,
size
+
1
));
}
}
}
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
,
forced
](
int
idx
)
{
callback_
=
[
this
,
forced
](
int
idx
)
{
const
auto
&
option
=
op
tions_
[
idx
];
const
auto
&
action
=
legal_ac
tions_
[
idx
];
if
(
option
==
"c"
)
{
if
(
action
.
act_
==
ActionAct
::
Cancel
)
{
if
(
forced
)
{
if
(
forced
)
{
fmt
::
print
(
"cancel not allowed in forced chain
\n
"
);
fmt
::
print
(
"cancel not allowed in forced chain
\n
"
);
YGO_SetResponsei
(
pduel_
,
0
);
YGO_SetResponsei
(
pduel_
,
0
);
...
@@ -4232,58 +4470,76 @@ private:
...
@@ -4232,58 +4470,76 @@ private:
};
};
}
else
if
(
msg_
==
MSG_SELECT_YESNO
)
{
}
else
if
(
msg_
==
MSG_SELECT_YESNO
)
{
auto
player
=
read_u8
();
auto
player
=
read_u8
();
auto
desc
=
read_u32
();
auto
[
code
,
eff_idx
]
=
unpack_desc
(
0
,
desc
);
if
(
desc
==
0
)
{
show_buffer
();
auto
s
=
fmt
::
format
(
"Unknown desc {} in select_yesno"
,
desc
);
throw
std
::
runtime_error
(
s
);
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
""
);
if
(
code
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
desc
=
read_u32
();
auto
pl
=
players_
[
player
];
auto
pl
=
players_
[
player
];
std
::
string
opt
;
std
::
string
s
;
if
(
desc
>
10000
)
{
if
(
code
==
0
)
{
auto
code
=
desc
>>
4
;
s
=
get_system_string
(
eff_idx
);
auto
card
=
c_get_card
(
code
);
}
else
{
auto
opt_idx
=
desc
&
0xf
;
Card
c
=
c_get_card
(
code
);
if
(
opt_idx
<
card
.
strings_
.
size
())
{
int
cmd_idx
=
legal_actions_
.
size
();
opt
=
card
.
strings_
[
opt_idx
];
eff_idx
-=
CARD_EFFECT_OFFSET
;
if
(
eff_idx
>=
c
.
strings_
.
size
())
{
throw
std
::
runtime_error
(
fmt
::
format
(
"Unknown effect {} of {}"
,
eff_idx
,
c
.
name_
));
}
}
if
(
opt
.
empty
())
{
auto
str
=
c
.
strings_
[
eff_idx
];
opt
=
"Unknown question from "
+
card
.
name_
+
". Yes or no?"
;
if
(
str
.
empty
())
{
str
=
"effect "
+
std
::
to_string
(
eff_idx
);
}
}
}
else
{
s
=
fmt
::
format
(
"{} ({})"
,
c
.
name_
,
str
);
opt
=
get_system_string
(
desc
);
}
}
pl
->
notify
(
opt
);
pl
->
notify
(
"1: "
+
s
);
pl
->
notify
(
"Please enter y or n."
);
pl
->
notify
(
"2: No"
);
}
else
{
dp_
+=
4
;
}
}
options_
=
{
"y"
,
"n"
};
// TODO: maybe add card id to cancel
legal_actions_
.
push_back
(
LegalAction
::
cancel
());
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
if
(
idx
==
0
)
{
if
(
idx
==
0
)
{
YGO_SetResponsei
(
pduel_
,
1
);
YGO_SetResponsei
(
pduel_
,
1
);
}
else
if
(
idx
==
1
)
{
}
else
if
(
idx
==
1
)
{
YGO_SetResponsei
(
pduel_
,
0
);
YGO_SetResponsei
(
pduel_
,
0
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option"
);
}
}
};
};
}
else
if
(
msg_
==
MSG_SELECT_EFFECTYN
)
{
}
else
if
(
msg_
==
MSG_SELECT_EFFECTYN
)
{
auto
player
=
read_u8
();
auto
player
=
read_u8
();
std
::
string
spec
;
CardCode
code
=
read_u32
();
auto
ct
=
read_u8
();
auto
loc
=
read_u8
();
auto
seq
=
read_u8
();
auto
pos
=
read_u8
();
auto
desc
=
read_u32
();
std
::
string
spec
=
ls_to_spec
(
loc
,
seq
,
pos
,
ct
!=
player
);
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
if
(
verbose_
)
{
CardCode
code
=
read_u32
();
Card
c
=
c_get_card
(
code
);
uint32_t
loc
=
read_u32
();
Card
card
=
c_get_card
(
code
);
card
.
set_location
(
loc
);
auto
desc
=
read_u32
();
auto
pl
=
players_
[
player
];
auto
pl
=
players_
[
player
];
spec
=
card
.
get_spec
(
player
);
auto
name
=
c
.
name_
;
auto
name
=
card
.
name_
;
std
::
string
s
;
std
::
string
s
;
if
(
desc
==
0
)
{
if
(
code_d
==
0
)
{
// From [%ls], activate [%ls]?
s
=
"From "
+
card
.
get_spec
(
player
)
+
", activate "
+
name
+
"?"
;
}
else
if
(
desc
<
2048
)
{
s
=
get_system_string
(
desc
);
s
=
get_system_string
(
desc
);
std
::
string
fmt_str
=
"[%ls]"
;
std
::
string
fmt_str
=
"[%ls]"
;
auto
pos
=
find_substrs
(
s
,
fmt_str
);
auto
pos
=
find_substrs
(
s
,
fmt_str
);
...
@@ -4295,87 +4551,74 @@ private:
...
@@ -4295,87 +4551,74 @@ private:
}
else
if
(
pos
.
size
()
==
2
)
{
}
else
if
(
pos
.
size
()
==
2
)
{
auto
p1
=
pos
[
0
];
auto
p1
=
pos
[
0
];
auto
p2
=
pos
[
1
];
auto
p2
=
pos
[
1
];
s
=
s
.
substr
(
0
,
p1
)
+
card
.
get_spec
(
player
)
+
s
=
s
.
substr
(
0
,
p1
)
+
spec
+
s
.
substr
(
p1
+
fmt_str
.
size
(),
p2
-
p1
-
fmt_str
.
size
())
+
name
+
s
.
substr
(
p1
+
fmt_str
.
size
(),
p2
-
p1
-
fmt_str
.
size
())
+
name
+
s
.
substr
(
p2
+
fmt_str
.
size
());
s
.
substr
(
p2
+
fmt_str
.
size
());
}
else
{
}
else
{
throw
std
::
runtime_error
(
"Unknown effectyn desc "
+
throw
std
::
runtime_error
(
"Unknown effectyn desc "
+
std
::
to_string
(
desc
)
+
" of "
+
name
);
std
::
to_string
(
desc
)
+
" of "
+
name
);
}
}
}
else
if
(
desc
<
10000u
)
{
s
=
get_system_string
(
desc
);
}
else
{
}
else
{
CardCode
code
=
(
desc
>>
4
)
&
0x0fffffff
;
s
=
fmt
::
format
(
uint32_t
offset
=
desc
&
0xf
;
"{}({}) ({})"
,
c
.
name_
,
spec
,
c
.
get_effect_description
(
code_d
,
eff_idx
));
if
(
cards_
.
find
(
code
)
!=
cards_
.
end
())
{
auto
&
card_
=
c_get_card
(
code
);
s
=
card_
.
strings_
[
offset
];
if
(
s
.
empty
())
{
s
=
"???"
;
}
}
else
{
throw
std
::
runtime_error
(
"Unknown effectyn desc "
+
std
::
to_string
(
desc
)
+
" of "
+
name
);
}
}
}
pl
->
notify
(
s
);
pl
->
notify
(
"1: "
+
s
);
pl
->
notify
(
"Please enter y or n."
);
pl
->
notify
(
"2: No"
);
}
else
{
dp_
+=
4
;
auto
c
=
read_u8
();
auto
loc
=
read_u8
();
auto
seq
=
read_u8
();
auto
pos
=
read_u8
();
dp_
+=
4
;
spec
=
ls_to_spec
(
loc
,
seq
,
pos
,
c
!=
player
);
}
}
options_
=
{
"y "
+
spec
,
"n "
+
spec
};
// TODO: maybe add card info to cancel
legal_actions_
.
push_back
(
LegalAction
::
cancel
());
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
if
(
idx
==
0
)
{
if
(
idx
==
0
)
{
YGO_SetResponsei
(
pduel_
,
1
);
YGO_SetResponsei
(
pduel_
,
1
);
}
else
if
(
idx
==
1
)
{
}
else
if
(
idx
==
1
)
{
YGO_SetResponsei
(
pduel_
,
0
);
YGO_SetResponsei
(
pduel_
,
0
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option"
);
}
}
};
};
}
else
if
(
msg_
==
MSG_SELECT_OPTION
)
{
}
else
if
(
msg_
==
MSG_SELECT_OPTION
)
{
// TODO: add card information
auto
player
=
read_u8
();
auto
player
=
read_u8
();
auto
size
=
read_u8
();
auto
size
=
read_u8
();
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
players_
[
player
]
->
notify
(
"Select an option:"
);
pl
->
notify
(
"Select an option:"
);
}
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
auto
opt
=
read_u32
();
auto
desc
=
read_u32
();
auto
[
code
,
eff_idx
]
=
unpack_desc
(
0
,
desc
);
if
(
desc
==
0
)
{
show_buffer
();
auto
s
=
fmt
::
format
(
"Unknown desc {} in select_option"
,
desc
);
throw
std
::
runtime_error
(
s
);
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
""
);
if
(
code
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
std
::
string
s
;
std
::
string
s
;
if
(
opt
>
10000
)
{
if
(
code
==
0
)
{
CardCode
code
=
opt
>>
4
;
s
=
get_system_string
(
eff_idx
);
s
=
c_get_card
(
code
).
strings_
[
opt
&
0xf
];
}
else
{
}
else
{
s
=
get_system_string
(
opt
);
Card
c
=
c_get_card
(
code
);
int
cmd_idx
=
legal_actions_
.
size
();
eff_idx
-=
CARD_EFFECT_OFFSET
;
if
(
eff_idx
>=
c
.
strings_
.
size
())
{
throw
std
::
runtime_error
(
fmt
::
format
(
"Unknown effect {} of {}"
,
eff_idx
,
c
.
name_
));
}
auto
str
=
c
.
strings_
[
eff_idx
];
if
(
str
.
empty
())
{
str
=
"effect "
+
std
::
to_string
(
eff_idx
);
}
s
=
fmt
::
format
(
"{} ({})"
,
c
.
name_
,
str
);
}
}
std
::
string
option
=
std
::
to_string
(
i
+
1
);
players_
[
player
]
->
notify
(
std
::
to_string
(
i
+
1
)
+
": "
+
s
);
options_
.
push_back
(
option
);
pl
->
notify
(
option
+
": "
+
s
);
}
}
else
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
dp_
+=
4
;
options_
.
push_back
(
std
::
to_string
(
i
+
1
));
}
}
}
}
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
if
(
verbose_
)
{
players_
[
to_play_
]
->
notify
(
"You selected option "
+
options_
[
idx
]
+
"."
);
players_
[
1
-
to_play_
]
->
notify
(
players_
[
to_play_
]
->
nickname_
+
" selected option "
+
options_
[
idx
]
+
"."
);
}
YGO_SetResponsei
(
pduel_
,
idx
);
YGO_SetResponsei
(
pduel_
,
idx
);
};
};
}
else
if
(
msg_
==
MSG_SELECT_IDLECMD
)
{
}
else
if
(
msg_
==
MSG_SELECT_IDLECMD
)
{
...
@@ -4397,90 +4640,97 @@ private:
...
@@ -4397,90 +4640,97 @@ private:
pl
->
notify
(
"Select a card and action to perform."
);
pl
->
notify
(
"Select a card and action to perform."
);
}
}
for
(
const
auto
&
[
code
,
spec
,
data
]
:
summonable_
)
{
for
(
const
auto
&
[
code
,
spec
,
data
]
:
summonable_
)
{
std
::
string
option
=
"s "
+
spec
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
Summon
,
spec
));
options_
.
push_back
(
option
);
if
(
verbose_
)
{
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Summon "
+
name
+
int
cmd_idx
=
legal_actions_
.
size
();
" in face-up attack position."
);
pl
->
notify
(
fmt
::
format
(
"{}: Summon {} in face-up attack position"
,
cmd_idx
,
name
));
}
}
}
}
offset
+=
summonable_
.
size
();
offset
+=
summonable_
.
size
();
int
spsummon_offset
=
offset
;
int
spsummon_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
spsummon_
)
{
for
(
const
auto
&
[
code
,
spec
,
data
]
:
spsummon_
)
{
std
::
string
option
=
"c "
+
spec
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
SpSummon
,
spec
));
options_
.
push_back
(
option
);
if
(
verbose_
)
{
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Special summon "
+
name
+
"."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Special summon {}"
,
cmd_idx
,
name
));
}
}
}
}
offset
+=
spsummon_
.
size
();
offset
+=
spsummon_
.
size
();
int
repos_offset
=
offset
;
int
repos_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
repos_
)
{
for
(
const
auto
&
[
code
,
spec
,
data
]
:
repos_
)
{
std
::
string
option
=
"r "
+
spec
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
Repo
,
spec
));
options_
.
push_back
(
option
);
if
(
verbose_
)
{
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Reposition "
+
name
+
"."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Change position of {}"
,
cmd_idx
,
name
));
}
}
}
}
offset
+=
repos_
.
size
();
offset
+=
repos_
.
size
();
int
mset_offset
=
offset
;
int
mset_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_mset_
)
{
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_mset_
)
{
std
::
string
option
=
"m "
+
spec
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
MSet
,
spec
));
options_
.
push_back
(
option
);
if
(
verbose_
)
{
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Summon "
+
name
+
int
cmd_idx
=
legal_actions_
.
size
();
" in face-down defense position."
);
pl
->
notify
(
fmt
::
format
(
"{}: Summon {} in face-down defense position"
,
cmd_idx
,
name
));
}
}
}
}
offset
+=
idle_mset_
.
size
();
offset
+=
idle_mset_
.
size
();
int
set_offset
=
offset
;
int
set_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_set_
)
{
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_set_
)
{
std
::
string
option
=
"t "
+
spec
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
Set
,
spec
));
options_
.
push_back
(
option
);
if
(
verbose_
)
{
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Set "
+
name
+
"."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Set {}"
,
cmd_idx
,
name
));
}
}
}
}
offset
+=
idle_set_
.
size
();
offset
+=
idle_set_
.
size
();
int
activate_offset
=
offset
;
int
activate_offset
=
offset
;
ankerl
::
unordered_dense
::
map
<
std
::
string
,
int
>
idle_activate_count
;
for
(
const
auto
&
[
code_t
,
spec
,
desc
]
:
idle_activate_
)
{
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_activate_
)
{
CardCode
code
=
code_t
;
idle_activate_count
[
spec
]
+=
1
;
if
(
code
&
0x80000000
)
{
}
code
&=
0x7fffffff
;
ankerl
::
unordered_dense
::
map
<
std
::
string
,
int
>
activate_count
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_activate_
)
{
// TODO: use effect description to indicate which effect to activate
std
::
string
option
=
"v "
+
spec
;
int
count
=
idle_activate_count
[
spec
];
activate_count
[
spec
]
++
;
if
(
count
>
1
)
{
option
.
push_back
(
'a'
+
activate_count
[
spec
]
-
1
);
}
}
options_
.
push_back
(
option
);
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
if
(
verbose_
)
{
pl
->
notify
(
option
+
": "
+
auto
c
=
c_get_card
(
code
);
c_get_card
(
code
).
get_effect_description
(
data
));
int
cmd_idx
=
legal_actions_
.
size
();
std
::
string
s
=
fmt
::
format
(
"{}: Activate {}({}) ({})"
,
cmd_idx
,
c
.
name_
,
spec
,
c
.
get_effect_description
(
code_d
,
eff_idx
));
pl
->
notify
(
s
);
}
}
}
}
if
(
to_bp_
)
{
if
(
to_bp_
)
{
std
::
string
cmd
=
"b"
;
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
Battle
));
options_
.
push_back
(
cmd
);
if
(
verbose_
)
{
if
(
verbose_
)
{
pl
->
notify
(
cmd
+
": Enter the battle phase."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Enter the battle phase."
,
cmd_idx
));
}
}
}
}
if
(
to_ep_
)
{
if
(
to_ep_
)
{
if
(
!
to_bp_
)
{
if
(
!
to_bp_
)
{
std
::
string
cmd
=
"e"
;
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
End
));
options_
.
push_back
(
cmd
);
if
(
verbose_
)
{
if
(
verbose_
)
{
pl
->
notify
(
cmd
+
": End phase."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: End phase."
,
cmd_idx
));
}
}
}
}
}
}
...
@@ -4488,104 +4738,90 @@ private:
...
@@ -4488,104 +4738,90 @@ private:
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
,
spsummon_offset
,
repos_offset
,
mset_offset
,
set_offset
,
callback_
=
[
this
,
spsummon_offset
,
repos_offset
,
mset_offset
,
set_offset
,
activate_offset
](
int
idx
)
{
activate_offset
](
int
idx
)
{
const
auto
&
option
=
options_
[
idx
];
const
auto
&
action
=
legal_actions_
[
idx
];
char
cmd
=
option
[
0
];
if
(
action
.
phase_
==
ActionPhase
::
Battle
)
{
if
(
cmd
==
'b'
)
{
YGO_SetResponsei
(
pduel_
,
6
);
YGO_SetResponsei
(
pduel_
,
6
);
}
else
if
(
cmd
==
'e'
)
{
}
else
if
(
action
.
phase_
==
ActionPhase
::
End
)
{
YGO_SetResponsei
(
pduel_
,
7
);
YGO_SetResponsei
(
pduel_
,
7
);
}
else
{
}
else
{
auto
spec
=
option
.
substr
(
2
)
;
auto
act
=
action
.
act_
;
if
(
cmd
==
's'
)
{
if
(
act
==
ActionAct
::
Summon
)
{
uint32_t
idx_
=
idx
;
uint32_t
idx_
=
idx
;
YGO_SetResponsei
(
pduel_
,
idx_
<<
16
);
YGO_SetResponsei
(
pduel_
,
idx_
<<
16
);
}
else
if
(
cmd
==
'c'
)
{
}
else
if
(
act
==
ActionAct
::
SpSummon
)
{
uint32_t
idx_
=
idx
-
spsummon_offset
;
uint32_t
idx_
=
idx
-
spsummon_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
1
);
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
1
);
}
else
if
(
cmd
==
'r'
)
{
}
else
if
(
act
==
ActionAct
::
Repo
)
{
uint32_t
idx_
=
idx
-
repos_offset
;
uint32_t
idx_
=
idx
-
repos_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
2
);
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
2
);
}
else
if
(
cmd
==
'm'
)
{
}
else
if
(
act
==
ActionAct
::
MSet
)
{
uint32_t
idx_
=
idx
-
mset_offset
;
uint32_t
idx_
=
idx
-
mset_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
3
);
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
3
);
}
else
if
(
cmd
==
't'
)
{
}
else
if
(
act
==
ActionAct
::
Set
)
{
uint32_t
idx_
=
idx
-
set_offset
;
uint32_t
idx_
=
idx
-
set_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
4
);
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
4
);
}
else
if
(
cmd
==
'v'
)
{
}
else
if
(
act
==
ActionAct
::
Activate
)
{
uint32_t
idx_
=
idx
-
activate_offset
;
uint32_t
idx_
=
idx
-
activate_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
5
);
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
5
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option: "
+
option
);
}
}
}
}
};
};
}
else
if
(
msg_
==
MSG_SELECT_PLACE
)
{
}
else
if
(
msg_
==
MSG_SELECT_PLACE
||
msg_
==
MSG_SELECT_DISFIELD
)
{
// TODO(1): add card informaton to select place
auto
player
=
read_u8
();
auto
player
=
read_u8
();
auto
count
=
read_u8
();
auto
count
=
read_u8
();
if
(
count
==
0
)
{
if
(
count
==
0
)
{
count
=
1
;
count
=
1
;
}
}
auto
flag
=
read_u32
();
if
(
count
!=
1
)
{
options_
=
flag_to_usable_cardspecs
(
flag
);
auto
s
=
fmt
::
format
(
"Select place count {} not implemented for {}"
,
if
(
verbose_
)
{
count
,
msg_
==
MSG_SELECT_PLACE
?
"place"
:
"disfield"
);
std
::
string
specs_str
=
options_
[
0
];
throw
std
::
runtime_error
(
s
);
for
(
int
i
=
1
;
i
<
options_
.
size
();
++
i
)
{
specs_str
+=
", "
+
options_
[
i
];
}
if
(
count
==
1
)
{
players_
[
player
]
->
notify
(
"Select place for card, one of "
+
specs_str
+
"."
);
}
else
{
players_
[
player
]
->
notify
(
"Select "
+
std
::
to_string
(
count
)
+
" places for card, from "
+
specs_str
+
"."
);
}
}
to_play_
=
player
;
callback_
=
[
this
,
player
](
int
idx
)
{
int
y
=
player
+
1
;
std
::
string
spec
=
options_
[
idx
];
auto
plr
=
player
;
if
(
spec
[
0
]
==
'o'
)
{
plr
=
1
-
player
;
spec
=
spec
.
substr
(
1
);
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
resp_buf_
[
0
]
=
plr
;
resp_buf_
[
1
]
=
loc
;
resp_buf_
[
2
]
=
seq
;
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
};
}
else
if
(
msg_
==
MSG_SELECT_DISFIELD
)
{
auto
player
=
read_u8
();
auto
count
=
read_u8
();
if
(
count
==
0
)
{
count
=
1
;
}
}
auto
flag
=
read_u32
();
auto
flag
=
read_u32
();
options_
=
flag_to_usable_cardspec
s
(
flag
);
auto
places
=
flag_to_usable_place
s
(
flag
);
if
(
verbose_
)
{
if
(
verbose_
)
{
std
::
string
specs_str
=
options_
[
0
];
auto
place_s
=
msg_
==
MSG_SELECT_PLACE
?
"place"
:
"disfield"
;
for
(
int
i
=
1
;
i
<
options_
.
size
();
++
i
)
{
auto
s
=
fmt
::
format
(
"Select {} for card, one of:"
,
place_s
);
specs_str
+=
", "
+
options_
[
i
];
players_
[
player
]
->
notify
(
s
);
}
}
if
(
count
==
1
)
{
for
(
int
i
=
0
;
i
<
places
.
size
();
++
i
)
{
players_
[
player
]
->
notify
(
"Select place for card, one of "
+
legal_actions_
.
push_back
(
LegalAction
::
place
(
places
[
i
]));
specs_str
+
"."
);
if
(
verbose_
)
{
}
else
{
auto
s
=
fmt
::
format
(
"{}: {}"
,
i
+
1
,
action_place_to_string
(
places
[
i
]));
throw
std
::
runtime_error
(
"Select disfield count "
+
players_
[
player
]
->
notify
(
s
);
std
::
to_string
(
count
)
+
" not implemented"
);
}
}
}
}
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
,
player
](
int
idx
)
{
callback_
=
[
this
,
player
](
int
idx
)
{
int
y
=
player
+
1
;
auto
place
=
legal_actions_
[
idx
].
place_
;
std
::
string
spec
=
options_
[
idx
];
int
i
=
static_cast
<
int
>
(
place
);
auto
plr
=
player
;
uint8_t
plr
=
player
;
if
(
spec
[
0
]
==
'o'
)
{
uint8_t
loc
;
uint8_t
seq
;
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
MZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
MZone7
))
{
loc
=
LOCATION_MZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
MZone1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
SZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
SZone8
))
{
loc
=
LOCATION_SZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
SZone1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
OpMZone7
))
{
plr
=
1
-
player
;
plr
=
1
-
player
;
spec
=
spec
.
substr
(
1
);
loc
=
LOCATION_MZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
OpSZone8
))
{
plr
=
1
-
player
;
loc
=
LOCATION_SZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
);
}
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
resp_buf_
[
0
]
=
plr
;
resp_buf_
[
0
]
=
plr
;
resp_buf_
[
1
]
=
loc
;
resp_buf_
[
1
]
=
loc
;
resp_buf_
[
2
]
=
seq
;
resp_buf_
[
2
]
=
seq
;
...
@@ -4620,7 +4856,7 @@ private:
...
@@ -4620,7 +4856,7 @@ private:
// auto spec = ls_to_spec(loc, seq, 0, controller != player);
// auto spec = ls_to_spec(loc, seq, 0, controller != player);
// options_.push_back(spec);
// options_.push_back(spec);
}
}
// TODO: implement action
// TODO
(2)
: implement action
n_counters_
=
count
;
n_counters_
=
count
;
uint16_t
resp1
=
static_cast
<
uint16_t
>
(
std
::
min
(
counter_count
,
counters
[
0
]));
uint16_t
resp1
=
static_cast
<
uint16_t
>
(
std
::
min
(
counter_count
,
counters
[
0
]));
memcpy
(
resp_buf_
,
&
resp1
,
2
);
memcpy
(
resp_buf_
,
&
resp1
,
2
);
...
@@ -4644,19 +4880,15 @@ private:
...
@@ -4644,19 +4880,15 @@ private:
" not implemented for announce number"
);
" not implemented for announce number"
);
}
}
numbers
.
push_back
(
number
);
numbers
.
push_back
(
number
);
options_
.
push_back
(
std
::
string
(
1
,
'0'
+
number
));
legal_actions_
.
push_back
(
LegalAction
::
number
(
number
));
}
}
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
auto
pl
=
players_
[
player
];
std
::
string
str
=
"Select a number, one of: ["
;
std
::
string
str
=
"Select a number, one of:"
;
pl
->
notify
(
str
);
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
str
+=
std
::
to_string
(
numbers
[
i
]);
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
i
+
1
,
numbers
[
i
]));
if
(
i
<
count
-
1
)
{
str
+=
", "
;
}
}
}
str
+=
"]"
;
pl
->
notify
(
str
);
}
}
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
...
@@ -4675,7 +4907,7 @@ private:
...
@@ -4675,7 +4907,7 @@ private:
attrs
.
push_back
(
i
+
1
);
attrs
.
push_back
(
i
+
1
);
}
}
}
}
// TODO(2): implement action
if
(
count
!=
1
)
{
if
(
count
!=
1
)
{
throw
std
::
runtime_error
(
"Announce attrib count "
+
throw
std
::
runtime_error
(
"Announce attrib count "
+
std
::
to_string
(
count
)
+
" not implemented"
);
std
::
to_string
(
count
)
+
" not implemented"
);
...
@@ -4686,40 +4918,28 @@ private:
...
@@ -4686,40 +4918,28 @@ private:
pl
->
notify
(
"Select "
+
std
::
to_string
(
count
)
+
pl
->
notify
(
"Select "
+
std
::
to_string
(
count
)
+
" attributes separated by spaces:"
);
" attributes separated by spaces:"
);
for
(
int
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
pl
->
notify
(
std
::
to_string
(
attrs
[
i
])
+
": "
+
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
i
+
1
,
attribute_to_string
(
1
<<
(
attrs
[
i
]
-
1
))));
attribute_to_string
(
1
<<
(
attrs
[
i
]
-
1
)));
}
}
}
}
auto
combs
=
combinations
(
attrs
.
size
(),
count
);
// auto combs = combinations(attrs.size(), count);
for
(
const
auto
&
comb
:
combs
)
{
for
(
int
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
std
::
string
option
=
""
;
legal_actions_
.
push_back
(
LegalAction
::
attribute
(
1
<<
(
attrs
[
i
]
-
1
)));
for
(
int
j
=
0
;
j
<
count
;
++
j
)
{
option
+=
std
::
to_string
(
attrs
[
comb
[
j
]]);
if
(
j
<
count
-
1
)
{
option
+=
" "
;
}
}
options_
.
push_back
(
option
);
}
}
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
const
auto
&
option
=
op
tions_
[
idx
];
const
auto
&
action
=
legal_ac
tions_
[
idx
];
uint32_t
resp
=
0
;
uint32_t
resp
=
0
;
int
i
=
0
;
resp
|=
action
.
attribute_
;
while
(
i
<
option
.
size
())
{
resp
|=
1
<<
(
option
[
i
]
-
'1'
);
i
+=
2
;
}
YGO_SetResponsei
(
pduel_
,
resp
);
YGO_SetResponsei
(
pduel_
,
resp
);
};
};
}
else
if
(
msg_
==
MSG_SELECT_POSITION
)
{
}
else
if
(
msg_
==
MSG_SELECT_POSITION
)
{
// TODO: add card as feature
auto
player
=
read_u8
();
auto
player
=
read_u8
();
auto
code
=
read_u32
();
auto
code
=
read_u32
();
auto
valid_pos
=
read_u8
();
auto
valid_pos
=
read_u8
();
CardId
cid
=
c_get_card_id
(
code
);
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
auto
pl
=
players_
[
player
];
...
@@ -4727,25 +4947,25 @@ private:
...
@@ -4727,25 +4947,25 @@ private:
pl
->
notify
(
"Select position for "
+
card
.
name_
+
":"
);
pl
->
notify
(
"Select position for "
+
card
.
name_
+
":"
);
}
}
std
::
vector
<
uint8_t
>
positions
;
int
i
=
1
;
for
(
auto
pos
:
{
POS_FACEUP_ATTACK
,
POS_FACEDOWN_ATTACK
,
for
(
auto
pos
:
{
POS_FACEUP_ATTACK
,
POS_FACEDOWN_ATTACK
,
POS_FACEUP_DEFENSE
,
POS_FACEDOWN_DEFENSE
})
{
POS_FACEUP_DEFENSE
,
POS_FACEDOWN_DEFENSE
})
{
if
(
valid_pos
&
pos
)
{
if
(
valid_pos
&
pos
)
{
positions
.
push_back
(
pos
);
LegalAction
la
;
options_
.
push_back
(
std
::
to_string
(
i
));
la
.
cid_
=
cid
;
la
.
position_
=
pos
;
legal_actions_
.
push_back
(
la
);
int
cmd_idx
=
legal_actions_
.
size
();
if
(
verbose_
)
{
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
auto
pl
=
players_
[
player
];
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
i
,
position_to_string
(
pos
)));
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
cmd_idx
,
position_to_string
(
pos
)));
}
}
}
}
i
++
;
}
}
to_play_
=
player
;
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
callback_
=
[
this
](
int
idx
)
{
uint8_t
pos
=
options_
[
idx
][
0
]
-
'1'
;
uint8_t
pos
=
legal_actions_
[
idx
].
position_
;
YGO_SetResponsei
(
pduel_
,
1
<<
pos
);
YGO_SetResponsei
(
pduel_
,
pos
);
};
};
}
else
{
}
else
{
show_deck
(
0
);
show_deck
(
0
);
...
@@ -4794,4 +5014,52 @@ using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
...
@@ -4794,4 +5014,52 @@ using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
}
// namespace ygopro
}
// namespace ygopro
template
<>
struct
fmt
::
formatter
<
ygopro
::
LegalAction
>:
formatter
<
string_view
>
{
// Format the LegalAction object
template
<
typename
FormatContext
>
auto
format
(
const
ygopro
::
LegalAction
&
action
,
FormatContext
&
ctx
)
const
{
std
::
stringstream
ss
;
ss
<<
"{"
;
if
(
!
action
.
spec_
.
empty
())
{
ss
<<
"spec='"
<<
action
.
spec_
<<
"', "
;
}
if
(
action
.
cid_
!=
0
)
{
ss
<<
"cid="
<<
action
.
cid_
<<
", "
;
}
if
(
action
.
act_
!=
ygopro
::
ActionAct
::
None
)
{
ss
<<
"act="
<<
ygopro
::
action_act_to_string
(
action
.
act_
)
<<
", "
;
}
if
(
action
.
phase_
!=
ygopro
::
ActionPhase
::
None
)
{
ss
<<
"phase="
<<
ygopro
::
action_phase_to_string
(
action
.
phase_
)
<<
", "
;
}
if
(
action
.
finish_
)
{
ss
<<
"finish=true, "
;
}
if
(
action
.
position_
!=
0
)
{
ss
<<
"position="
<<
ygopro
::
position_to_string
(
action
.
position_
)
<<
", "
;
}
if
(
action
.
effect_
!=
-
1
)
{
ss
<<
"effect="
<<
action
.
effect_
<<
", "
;
}
if
(
action
.
number_
!=
0
)
{
ss
<<
"number="
<<
int
(
action
.
number_
)
<<
", "
;
}
if
(
action
.
place_
!=
ygopro
::
ActionPlace
::
None
)
{
ss
<<
"place="
<<
ygopro
::
action_place_to_string
(
action
.
place_
)
<<
", "
;
}
if
(
action
.
attribute_
!=
0
)
{
ss
<<
"attribute="
<<
ygopro
::
attribute_to_string
(
action
.
attribute_
)
<<
", "
;
}
std
::
string
s
=
ss
.
str
();
if
(
s
.
back
()
==
' '
)
{
s
.
pop_back
();
s
.
pop_back
();
}
s
.
push_back
(
'}'
);
return
format_to
(
ctx
.
out
(),
"{}"
,
s
);
}
};
#endif // YGOENV_YGOPRO_YGOPRO_H_
#endif // YGOENV_YGOPRO_YGOPRO_H_
ygoenv/ygoenv/ygopro0/__init__.py
0 → 100644
View file @
04e61b91
from
ygoenv.python.api
import
py_env
from
.ygopro0_ygoenv
import
(
_YGOPro0EnvPool
,
_YGOPro0EnvSpec
,
init_module
,
)
(
YGOPro0EnvSpec
,
YGOPro0DMEnvPool
,
YGOPro0GymEnvPool
,
YGOPro0GymnasiumEnvPool
,
)
=
py_env
(
_YGOPro0EnvSpec
,
_YGOPro0EnvPool
)
__all__
=
[
"YGOPro0EnvSpec"
,
"YGOPro0DMEnvPool"
,
"YGOPro0GymEnvPool"
,
"YGOPro0GymnasiumEnvPool"
,
]
ygoenv/ygoenv/ygopro0/registration.py
0 → 100644
View file @
04e61b91
from
ygoenv.registration
import
register
register
(
task_id
=
"YGOPro-v0"
,
import_path
=
"ygoenv.ygopro0"
,
spec_cls
=
"YGOPro0EnvSpec"
,
dm_cls
=
"YGOPro0DMEnvPool"
,
gym_cls
=
"YGOPro0GymEnvPool"
,
gymnasium_cls
=
"YGOPro0GymnasiumEnvPool"
,
)
ygoenv/ygoenv/ygopro0/ygopro.cpp
0 → 100644
View file @
04e61b91
#include "ygoenv/ygopro0/ygopro.h"
#include "ygoenv/core/py_envpool.h"
using
YGOPro0EnvSpec
=
PyEnvSpec
<
ygopro0
::
YGOProEnvSpec
>
;
using
YGOPro0EnvPool
=
PyEnvPool
<
ygopro0
::
YGOProEnvPool
>
;
PYBIND11_MODULE
(
ygopro0_ygoenv
,
m
)
{
REGISTER
(
m
,
YGOPro0EnvSpec
,
YGOPro0EnvPool
)
m
.
def
(
"init_module"
,
&
ygopro0
::
init_module
);
}
ygoenv/ygoenv/ygopro0/ygopro.h
0 → 100644
View file @
04e61b91
This source diff could not be displayed because it is too large. You can
view the blob
instead.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment