Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
04e61b91
Commit
04e61b91
authored
May 21, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add YGOPro-v1
parent
f6139c17
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
6345 additions
and
1358 deletions
+6345
-1358
docs/action.md
docs/action.md
+29
-0
docs/feature_engineering.md
docs/feature_engineering.md
+23
-31
scripts/battle.py
scripts/battle.py
+1
-1
scripts/eval.py
scripts/eval.py
+1
-2
xmake.lua
xmake.lua
+18
-1
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+491
-163
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+0
-494
ygoai/utils.py
ygoai/utils.py
+6
-1
ygoenv/ygoenv/core/BS_thread_pool.h
ygoenv/ygoenv/core/BS_thread_pool.h
+0
-0
ygoenv/ygoenv/entry.py
ygoenv/ygoenv/entry.py
+4
-1
ygoenv/ygoenv/ygopro/registration.py
ygoenv/ygoenv/ygopro/registration.py
+1
-1
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+931
-663
ygoenv/ygoenv/ygopro0/__init__.py
ygoenv/ygoenv/ygopro0/__init__.py
+22
-0
ygoenv/ygoenv/ygopro0/registration.py
ygoenv/ygoenv/ygopro0/registration.py
+10
-0
ygoenv/ygoenv/ygopro0/ygopro.cpp
ygoenv/ygoenv/ygopro0/ygopro.cpp
+11
-0
ygoenv/ygoenv/ygopro0/ygopro.h
ygoenv/ygoenv/ygopro0/ygopro.h
+4797
-0
No files found.
docs/action.md
0 → 100644
View file @
04e61b91
# Action
## Types
-
Set + card
-
Reposition + card
-
Special summon + card
-
Summon Face-up Attack + card
-
Summon Face-down Defense + card
-
Attack + card
-
DirectAttack + card
-
Activate + card + effect
-
Cancel
-
Switch + phase
-
SelectPosition + card + position
-
AnnounceNumber + card + effect + number
-
SelectPlace + card + place
-
AnnounceAttrib + card + effect + attrib
## Effect
### MSG_SELECT_BATTLECMD | MSG_SELECT_IDLECMD | MSG_SELECT_CHAIN | MSG_SELECT_EFFECTYN
-
desc == 0: default effect of card
-
desc < LIMIT: system string
-
desc > LIMIT: card + effect
### MSG_SELECT_OPTION | MSG_SELECT_YESNO
-
desc == 0: error
-
desc < LIMIT: system string
-
desc > LIMIT: card + effect
docs/feature_engineering.md
View file @
04e61b91
...
...
@@ -49,50 +49,42 @@ The card id is the index of the card code in `code_list.txt`.
## Legal Actions
-
0,1: spec index, uint16 -> 2 uint8
-
2: msg, discrete, 0: N/A, 1+: same as msg2str (15)
-
3: act, discrete (11)
-
0: spec index
-
1,2: code, uint16 -> 2 uint8
-
3: msg, discrete, 0: N/A, 1+: same as msg2str (15)
-
4: act, discrete (11)
-
N/A
-
t: Set
-
r: Reposition
-
c: Special Summon
-
s: Summon Face-up Attack
-
m: Summon Face-down Defense
-
a: Attack
-
v: Activate
-
v2: Activate the second effect
-
v3: Activate the third effect
-
v4: Activate the fourth effect
-
4: yes/no, discrete (3)
-
Set
-
Reposition
-
Special Summon
-
Summon Face-up Attack
-
Summon Face-down Defense
-
Attack
-
DirectAttack
-
Activate
-
Cancel
-
5: finish, discrete (2)
-
N/A
-
Yes
-
No
-
5
: phase, discrete (4)
-
Finish
-
6: effect, discrete, 0: N/A
-
7
: phase, discrete (4)
-
N/A
-
Battle (b)
-
Main Phase 2 (m)
-
End Phase (e)
-
6: cancel, discrete (2)
-
N/A
-
Cancel
-
7: finish, discrete (2)
-
N/A
-
Finish
-
8: position, discrete, 0: N/A, same as position2str
-
9: option, discrete, 0: N/A
-
10: number, discrete, 0: N/A
-
11: place, discrete
-
9: number, discrete, 0: N/A
-
10: place, discrete
-
0: N/A
-
1-7: m
-
8-15: s
-
16-22: om
-
23-30: os
-
1
2
: attribute, discrete, 0: N/A, same as attribute2id
-
1
1
: attribute, discrete, 0: N/A, same as attribute2id
## History Actions
-
0,1: card id, uint16 -> 2 uint8
-
2-12 same as legal actions
-
13: player, discrete, 0: me, 1: oppo
-
14: turn, discrete, trunc to 3
-
2-11 same as legal actions
-
12: turn, discrete, trunc to 3
-
13: phase, discrete (10)
scripts/battle.py
View file @
04e61b91
...
...
@@ -18,7 +18,7 @@ import flax
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.jax.agent
2
import
RNNAgent
,
ModelArgs
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
@
dataclass
...
...
scripts/eval.py
View file @
04e61b91
...
...
@@ -135,7 +135,7 @@ if __name__ == "__main__":
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.rl.jax.agent
2
import
RNNAgent
from
ygoai.rl.jax.agent
import
RNNAgent
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
...
...
@@ -168,7 +168,6 @@ if __name__ == "__main__":
obs
,
infos
=
envs
.
reset
()
print
(
obs
)
next_to_play
=
infos
[
'to_play'
]
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
...
...
xmake.lua
View file @
04e61b91
...
...
@@ -8,6 +8,24 @@ add_requires(
"sqlitecpp 3.2.1"
)
target
(
"ygopro0_ygoenv"
)
add_rules
(
"python.library"
)
add_files
(
"ygoenv/ygoenv/ygopro0/*.cpp"
)
add_packages
(
"pybind11"
,
"fmt"
,
"glog"
,
"concurrentqueue"
,
"sqlitecpp"
,
"unordered_dense"
,
"ygopro-core"
)
set_languages
(
"c++17"
)
if
is_mode
(
"release"
)
then
set_policy
(
"build.optimization.lto"
,
true
)
add_cxxflags
(
"-march=native"
)
end
add_includedirs
(
"ygoenv"
)
after_build
(
function
(
target
)
local
install_target
=
"$(projectdir)/ygoenv/ygoenv/ygopro0"
os
.
cp
(
target
:
targetfile
(),
install_target
)
print
(
"Copy target to "
..
install_target
)
end
)
target
(
"ygopro_ygoenv"
)
add_rules
(
"python.library"
)
add_files
(
"ygoenv/ygoenv/ygopro/*.cpp"
)
...
...
@@ -25,7 +43,6 @@ target("ygopro_ygoenv")
print
(
"Copy target to "
..
install_target
)
end
)
target
(
"edopro_ygoenv"
)
add_rules
(
"python.library"
)
add_files
(
"ygoenv/ygoenv/edopro/*.cpp"
)
...
...
ygoai/rl/jax/agent.py
View file @
04e61b91
from
typing
import
Tuple
,
Union
,
Optional
from
dataclasses
import
dataclass
from
typing
import
Tuple
,
Union
,
Optional
,
Sequence
,
Literal
from
functools
import
partial
import
numpy
as
np
import
jax
import
jax.numpy
as
jnp
import
flax.linen
as
nn
from
ygoai.rl.jax.transformer
import
EncoderLayer
,
PositionalEncoding
,
LlamaEncoderLayer
from
ygoai.rl.jax.modules
import
MLP
,
make_bin_params
,
bytes_to_bin
,
decode_id
from
ygoai.rl.jax.transformer
import
EncoderLayer
,
DecoderLayer
,
PositionalEncoding
default_embed_init
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
...
...
@@ -14,11 +16,18 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2
=
nn
.
initializers
.
uniform
(
scale
=
0.001
)
def
get_encoder_layer_cls
(
noam
,
n_heads
,
dtype
,
param_dtype
):
if
noam
:
return
LlamaEncoderLayer
(
n_heads
,
dtype
=
dtype
,
param_dtype
=
param_dtype
,
rope
=
False
)
else
:
return
EncoderLayer
(
n_heads
,
dtype
=
dtype
,
param_dtype
=
param_dtype
)
class
ActionEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
channels
...
...
@@ -26,7 +35,6 @@ class ActionEncoder(nn.Module):
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_act
=
embed
(
13
,
c
//
div
)(
x
[:,
:,
1
])
x_a_yesno
=
embed
(
3
,
c
//
div
)(
x
[:,
:,
2
])
...
...
@@ -38,18 +46,165 @@ class ActionEncoder(nn.Module):
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
8
])
x_a_place
=
embed
(
31
,
c
//
div
//
2
)(
x
[:,
:,
9
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
10
])
return
jnp
.
concatenate
([
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
],
axis
=-
1
)
xs
=
[
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
]
return
xs
class
ActionEncoderV1
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
channels
div
=
8
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_act
=
embed
(
10
,
c
//
div
)(
x
[:,
:,
1
])
x_a_finish
=
embed
(
3
,
c
//
div
//
2
)(
x
[:,
:,
2
])
x_a_effect
=
embed
(
256
,
c
//
div
*
2
)(
x
[:,
:,
3
])
x_a_phase
=
embed
(
4
,
c
//
div
//
2
)(
x
[:,
:,
4
])
x_a_position
=
embed
(
9
,
c
//
div
)(
x
[:,
:,
5
])
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
6
])
x_a_place
=
embed
(
31
,
c
//
div
)(
x
[:,
:,
7
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
8
])
xs
=
[
x_a_msg
,
x_a_act
,
x_a_finish
,
x_a_effect
,
x_a_phase
,
x_a_position
,
x_a_number
,
x_a_place
,
x_a_attrib
]
return
xs
class
CardEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
version
:
int
=
0
@
nn
.
compact
def
__call__
(
self
,
x_id
,
x
):
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
num_fc
=
mlp
((
c
//
8
,),
last_lin
=
False
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
x1
=
x
[:,
:,
:
10
]
.
astype
(
jnp
.
int32
)
x2
=
x
[:,
:,
10
:]
.
astype
(
jnp
.
float32
)
x_loc
=
x1
[:,
:,
0
]
x_seq
=
x1
[:,
:,
1
]
if
self
.
version
==
0
:
x_id
=
mlp
(
(
c
,
c
//
4
),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
layer_norm
()(
x_id
)
f_loc
=
layer_norm
()(
embed
(
9
,
c
)(
x_loc
))
f_seq
=
layer_norm
()(
embed
(
76
,
c
)(
x_seq
))
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
x_owner
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
2
])
x_position
=
embed
(
9
,
c
//
16
)(
x1
[:,
:,
3
])
x_overley
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
4
])
x_attribute
=
embed
(
8
,
c
//
16
)(
x1
[:,
:,
5
])
x_race
=
embed
(
27
,
c
//
16
)(
x1
[:,
:,
6
])
x_level
=
embed
(
14
,
c
//
16
)(
x1
[:,
:,
7
])
x_counter
=
embed
(
16
,
c
//
16
)(
x1
[:,
:,
8
])
x_negated
=
embed
(
3
,
c
//
16
)(
x1
[:,
:,
9
])
x_atk
=
num_transform
(
x2
[:,
:,
0
:
2
])
x_atk
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_atk
)
x_def
=
num_transform
(
x2
[:,
:,
2
:
4
])
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x2
[:,
:,
4
:])
if
self
.
version
==
0
:
x_f
=
jnp
.
concatenate
([
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_f
=
layer_norm
()(
x_f
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
else
:
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
f_seq
=
embed
(
76
,
c
//
16
*
2
)(
x_seq
)
x_cards
=
jnp
.
concatenate
([
f_loc
,
f_seq
,
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_cards
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards
)
x_cards
=
x_cards
*
x_id
f_cards
=
layer_norm
()(
x_cards
)
return
f_cards
,
c_mask
class
GlobalEncoder
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
batch_size
=
x
.
shape
[
0
]
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
count_embed
=
embed
(
100
,
c
//
16
)
hand_count_embed
=
embed
(
100
,
c
//
16
)
num_fc
=
mlp
((
c
//
8
,),
last_lin
=
False
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
x1
=
x
[:,
:
4
]
.
astype
(
jnp
.
float32
)
x2
=
x
[:,
4
:
8
]
.
astype
(
jnp
.
int32
)
x3
=
x
[:,
8
:
22
]
.
astype
(
jnp
.
int32
)
x_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x1
[:,
0
:
2
]))
x_oppo_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x1
[:,
2
:
4
]))
x_turn
=
embed
(
20
,
c
//
8
)(
x2
[:,
0
])
x_phase
=
embed
(
11
,
c
//
8
)(
x2
[:,
1
])
x_if_first
=
embed
(
2
,
c
//
8
)(
x2
[:,
2
])
x_is_my_turn
=
embed
(
2
,
c
//
8
)(
x2
[:,
3
])
x_cs
=
count_embed
(
x3
)
.
reshape
((
batch_size
,
-
1
))
x_my_hand_c
=
hand_count_embed
(
x3
[:,
1
])
x_op_hand_c
=
hand_count_embed
(
x3
[:,
8
])
x
=
jnp
.
concatenate
([
x_lp
,
x_oppo_lp
,
x_turn
,
x_phase
,
x_if_first
,
x_is_my_turn
,
x_cs
,
x_my_hand_c
,
x_op_hand_c
],
axis
=-
1
)
x
=
layer_norm
()(
x
)
return
x
class
Encoder
(
nn
.
Module
):
channels
:
int
=
128
num_card_layers
:
int
=
2
num_action_layers
:
int
=
2
num_layers
:
int
=
2
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
freeze_id
:
bool
=
False
use_history
:
bool
=
True
card_mask
:
bool
=
False
noam
:
bool
=
False
version
:
int
=
0
@
nn
.
compact
def
__call__
(
self
,
x
):
...
...
@@ -62,154 +217,182 @@ class Encoder(nn.Module):
n_embed
,
embed_dim
=
self
.
embedding_shape
n_embed
=
1
+
n_embed
# 1 (index 0) for unknown
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
False
,
use_bias
=
Fals
e
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
Tru
e
)
embed
=
partial
(
nn
.
Embed
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
fc_embed
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
count_embed
=
embed
(
100
,
c
//
16
)
hand_count_embed
=
embed
(
100
,
c
//
16
)
num_fc
=
MLP
((
c
//
8
,),
last_lin
=
False
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
bin_points
,
bin_intervals
=
make_bin_params
(
n_bins
=
32
)
num_transform
=
lambda
x
:
num_fc
(
bytes_to_bin
(
x
,
bin_points
,
bin_intervals
))
action_encoder
=
ActionEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
action_encoder
=
ActionEncoderCls
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
batch_size
=
x_cards
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
x_cards_1
=
x_cards
[:,
:,
:
12
]
.
astype
(
jnp
.
int32
)
x_cards_2
=
x_cards
[:,
:,
12
:]
.
astype
(
jnp
.
float32
)
x_id
=
decode_id
(
x_cards_1
[:,
:,
:
2
])
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
id_embed
(
x_id
)
x_id
=
MLP
(
(
c
,
c
//
4
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
layer_norm
()(
x_id
)
x_loc
=
x_cards_1
[:,
:,
2
]
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
f_loc
=
layer_norm
()(
embed
(
9
,
c
)(
x_loc
))
if
self
.
freeze_id
:
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
x_seq
=
x_cards_1
[:,
:,
3
]
f_seq
=
layer_norm
()(
embed
(
76
,
c
)(
x_seq
))
x_owner
=
embed
(
2
,
c
//
16
)(
x_cards_1
[:,
:,
4
])
x_position
=
embed
(
9
,
c
//
16
)(
x_cards_1
[:,
:,
5
])
x_overley
=
embed
(
2
,
c
//
16
)(
x_cards_1
[:,
:,
6
])
x_attribute
=
embed
(
8
,
c
//
16
)(
x_cards_1
[:,
:,
7
])
x_race
=
embed
(
27
,
c
//
16
)(
x_cards_1
[:,
:,
8
])
x_level
=
embed
(
14
,
c
//
16
)(
x_cards_1
[:,
:,
9
])
x_counter
=
embed
(
16
,
c
//
16
)(
x_cards_1
[:,
:,
10
])
x_negated
=
embed
(
3
,
c
//
16
)(
x_cards_1
[:,
:,
11
])
x_atk
=
num_transform
(
x_cards_2
[:,
:,
0
:
2
])
x_atk
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_atk
)
x_def
=
num_transform
(
x_cards_2
[:,
:,
2
:
4
])
x_def
=
fc_embed
(
c
//
16
,
kernel_init
=
default_fc_init1
)(
x_def
)
x_type
=
fc_embed
(
c
//
16
*
2
,
kernel_init
=
default_fc_init2
)(
x_cards_2
[:,
:,
4
:])
# Cards
f_cards
,
c_mask
=
CardEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
)(
x_id
,
x_cards
[:,
:,
2
:])
g_card_embed
=
self
.
param
(
'g_card_embed'
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
(
1
,
c
),
self
.
param_dtype
)
f_g_card
=
jnp
.
tile
(
g_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_cards
=
jnp
.
concatenate
([
f_g_card
,
f_cards
],
axis
=
1
)
if
self
.
card_mask
:
c_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
c_mask
.
dtype
),
c_mask
],
axis
=
1
)
else
:
c_mask
=
None
x_feat
=
jnp
.
concatenate
([
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_feat
=
layer_norm
()(
x_feat
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_feat
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
num_heads
=
max
(
2
,
c
//
128
)
for
_
in
range
(
self
.
num_card_layers
):
f_cards
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_cards
)
for
_
in
range
(
self
.
num_layers
):
f_cards
=
get_encoder_layer_cls
(
self
.
noam
,
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_cards
,
src_key_padding_mask
=
c_mask
)
f_cards
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_cards
)
f_g_card
=
f_cards
[:,
0
]
# Global
x_global
=
GlobalEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
x_global
=
x_global
.
astype
(
self
.
dtype
)
f_global
=
x_global
+
MLP
((
c
*
2
,
c
*
2
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
f_global
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_global
)
f_global
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_global
)
# History actions
x_h_actions
=
x_h_actions
.
astype
(
jnp
.
int32
)
if
self
.
version
==
0
:
h_mask
=
x_h_actions
[:,
:,
2
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
:
2
])
x_h_id
=
id_embed
(
x_h_id
)
if
self
.
freeze_id
:
x_h_id
=
jax
.
lax
.
stop_gradient
(
x_h_id
)
x_h_id
=
MLP
(
(
c
,
c
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
x_h_id
)
x_h_a_feats1
=
action_encoder
(
x_h_actions
[:,
:,
2
:
13
])
x_h_a_player
=
embed
(
2
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
14
])
x_h_a_feats
=
jnp
.
concatenate
([
*
x_h_a_feats1
,
x_h_a_player
,
x_h_a_turn
],
axis
=-
1
)
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_h_a_feats
))
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
for
_
in
range
(
self
.
num_layers
):
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_h_actions
,
src_key_padding_mask
=
h_mask
)
f_g_h_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_h_actions
[:,
0
])
else
:
h_mask
=
x_h_actions
[:,
:,
3
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
1
:
3
])
x_h_id
=
id_embed
(
x_h_id
)
if
self
.
freeze_id
:
x_h_id
=
jax
.
lax
.
stop_gradient
(
x_h_id
)
x_h_id
=
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_h_id
)
x_h_a_feats
=
action_encoder
(
x_h_actions
[:,
:,
3
:
12
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
12
])
x_h_a_phase
=
embed
(
12
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_feats
.
extend
([
x_h_id
,
x_h_a_turn
,
x_h_a_phase
])
x_h_a_feats
=
jnp
.
concatenate
(
x_h_a_feats
,
axis
=-
1
)
x_h_a_feats
=
layer_norm
()(
x_h_a_feats
)
x_h_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_h_a_feats
)
if
self
.
noam
:
f_h_actions
=
LlamaEncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
rope
=
True
,
n_positions
=
64
)(
x_h_a_feats
,
src_key_padding_mask
=
h_mask
)
else
:
x_h_a_feats
=
PositionalEncoding
()(
x_h_a_feats
)
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
x_h_a_feats
,
src_key_padding_mask
=
h_mask
)
f_g_h_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_h_actions
[:,
0
])
# Actions
x_actions
=
x_actions
.
astype
(
jnp
.
int32
)
na_card_embed
=
self
.
param
(
'na_card_embed'
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
(
1
,
c
),
self
.
param_dtype
)
f_na_card
=
jnp
.
tile
(
na_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
],
axis
=
1
)
c_mask
=
jnp
.
concatenate
([
jnp
.
zeros
((
batch_size
,
1
),
dtype
=
c_mask
.
dtype
),
c_mask
],
axis
=
1
)
f_cards
=
layer_norm
()(
f_cards
)
x_global_1
=
x_global
[:,
:
4
]
.
astype
(
jnp
.
float32
)
x_g_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x_global_1
[:,
0
:
2
]))
x_g_oppo_lp
=
fc_embed
(
c
//
4
,
kernel_init
=
default_fc_init2
)(
num_transform
(
x_global_1
[:,
2
:
4
]))
x_global_2
=
x_global
[:,
4
:
8
]
.
astype
(
jnp
.
int32
)
x_g_turn
=
embed
(
20
,
c
//
8
)(
x_global_2
[:,
0
])
x_g_phase
=
embed
(
11
,
c
//
8
)(
x_global_2
[:,
1
])
x_g_if_first
=
embed
(
2
,
c
//
8
)(
x_global_2
[:,
2
])
x_g_is_my_turn
=
embed
(
2
,
c
//
8
)(
x_global_2
[:,
3
])
x_global_3
=
x_global
[:,
8
:
22
]
.
astype
(
jnp
.
int32
)
x_g_cs
=
count_embed
(
x_global_3
)
.
reshape
((
batch_size
,
-
1
))
x_g_my_hand_c
=
hand_count_embed
(
x_global_3
[:,
1
])
x_g_op_hand_c
=
hand_count_embed
(
x_global_3
[:,
8
])
x_global
=
jnp
.
concatenate
([
x_g_lp
,
x_g_oppo_lp
,
x_g_turn
,
x_g_phase
,
x_g_if_first
,
x_g_is_my_turn
,
x_g_cs
,
x_g_my_hand_c
,
x_g_op_hand_c
],
axis
=-
1
)
x_global
=
layer_norm
()(
x_global
)
f_global
=
x_global
+
MLP
((
c
*
2
,
c
*
2
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
x_global
)
f_global
=
fc_layer
(
c
)(
f_global
)
f_global
=
layer_norm
()(
f_global
)
f_cards
=
f_cards
+
jnp
.
expand_dims
(
f_global
,
1
)
x_actions
=
x_actions
.
astype
(
jnp
.
int32
)
spec_index
=
decode_id
(
x_actions
[
...
,
:
2
])
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
f_a_cards
+
fc_layer
(
c
)(
layer_norm
()(
f_a_cards
))
x_a_feats
=
action_encoder
(
x_actions
[
...
,
2
:])
f_actions
=
f_a_cards
+
layer_norm
()(
x_a_feats
)
a_mask
=
x_actions
[:,
:,
2
]
==
0
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
for
_
in
range
(
self
.
num_action_layers
):
f_actions
=
DecoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
f_cards
,
tgt_key_padding_mask
=
a_mask
,
memory_key_padding_mask
=
c_mask
)
x_h_actions
=
x
[
'h_actions_'
]
.
astype
(
jnp
.
int32
)
h_mask
=
x_h_actions
[:,
:,
2
]
==
0
# msg == 0
h_mask
=
h_mask
.
at
[:,
0
]
.
set
(
False
)
x_h_id
=
decode_id
(
x_h_actions
[
...
,
:
2
])
x_h_id
=
MLP
(
(
c
,
c
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
id_embed
(
x_h_id
))
x_h_a_feats
=
action_encoder
(
x_h_actions
[:,
:,
2
:])
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
x_h_a_feats
)
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
for
_
in
range
(
self
.
num_action_layers
):
f_h_actions
=
EncoderLayer
(
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_h_actions
,
src_key_padding_mask
=
h_mask
)
for
_
in
range
(
self
.
num_action_layers
):
f_actions
=
DecoderLayer
(
num_heads
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
f_h_actions
,
tgt_key_padding_mask
=
a_mask
,
memory_key_padding_mask
=
h_mask
)
f_actions
=
layer_norm
()(
f_actions
)
f_s_cards_global
=
f_cards
.
mean
(
axis
=
1
)
c_mask
=
1
-
a_mask
[:,
:,
None
]
.
astype
(
f_actions
.
dtype
)
f_s_actions_ha
=
(
f_actions
*
c_mask
)
.
sum
(
axis
=
1
)
/
c_mask
.
sum
(
axis
=
1
)
f_state
=
jnp
.
concatenate
([
f_s_cards_global
,
f_s_actions_ha
],
axis
=-
1
)
f_cards
=
jnp
.
concatenate
([
f_na_card
,
f_cards
[:,
1
:]],
axis
=
1
)
if
self
.
version
==
0
:
spec_index
=
decode_id
(
x_actions
[
...
,
:
2
])
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_a_cards
)
x_a_feats
=
jnp
.
concatenate
(
action_encoder
(
x_actions
[
...
,
2
:]),
axis
=-
1
)
x_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_a_feats
)
f_actions
=
jnp
.
concatenate
([
f_a_cards
,
x_a_feats
],
axis
=-
1
)
f_actions
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
nn
.
leaky_relu
(
f_actions
,
negative_slope
=
0.1
))
f_actions
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_actions
)
a_mask
=
x_actions
[:,
:,
2
]
==
0
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
a_mask_
=
(
1
-
a_mask
.
astype
(
f_actions
.
dtype
))
f_g_actions
=
(
f_actions
*
a_mask_
[:,
:,
None
])
.
sum
(
axis
=
1
)
f_g_actions
=
f_g_actions
/
a_mask_
.
sum
(
axis
=
1
,
keepdims
=
True
)
if
not
self
.
use_history
:
f_g_h_actions
=
jnp
.
zeros_like
(
f_g_h_actions
)
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
else
:
spec_index
=
x_actions
[
...
,
0
]
B
=
jnp
.
arange
(
batch_size
)
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
x_a_id
=
decode_id
(
x_actions
[
...
,
1
:
3
])
x_a_id
=
id_embed
(
x_a_id
)
if
self
.
freeze_id
:
x_a_id
=
jax
.
lax
.
stop_gradient
(
x_a_id
)
x_a_id
=
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_a_id
)
x_a_feats
=
action_encoder
(
x_actions
[
...
,
3
:])
x_a_feats
.
append
(
x_a_id
)
x_a_feats
=
jnp
.
concatenate
(
x_a_feats
,
axis
=-
1
)
x_a_feats
=
layer_norm
()(
x_a_feats
)
x_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_a_feats
)
f_a_cards
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_a_cards
)
f_actions
=
jax
.
nn
.
silu
(
f_a_cards
)
*
x_a_feats
f_actions
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_actions
)
f_actions
=
x_a_feats
+
f_actions
a_mask
=
x_actions
[:,
:,
3
]
==
0
a_mask
=
a_mask
.
at
[:,
0
]
.
set
(
False
)
f_actions_g
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_actions
)
a_mask_
=
(
1
-
a_mask
.
astype
(
f_actions
.
dtype
))
f_g_actions
=
(
f_actions_g
*
a_mask_
[:,
:,
None
])
.
sum
(
axis
=
1
)
f_g_actions
=
f_g_actions
/
a_mask_
.
sum
(
axis
=
1
,
keepdims
=
True
)
if
self
.
use_history
:
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
else
:
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_actions
],
axis
=-
1
)
f_state
=
MLP
((
c
*
2
,
c
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
return
f_actions
,
f_state
,
a_mask
,
valid
...
...
@@ -219,54 +402,199 @@ class Actor(nn.Module):
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
f_actions
,
mask
):
def
__call__
(
self
,
f_state
,
f_actions
,
mask
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
f_actions
=
f_actions
.
astype
(
self
.
dtype
)
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
last_kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))
f_state
=
mlp
((
c
,),
use_bias
=
True
)(
f_state
)
logits
=
jnp
.
einsum
(
'bc,bnc->bn'
,
f_state
,
f_actions
)
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
return
logits
class
FiLMActor
(
nn
.
Module
):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
noam
:
bool
=
False
@
nn
.
compact
def
__call__
(
self
,
f_state
,
f_actions
,
mask
):
f_state
=
f_state
.
astype
(
self
.
dtype
)
f_actions
=
f_actions
.
astype
(
self
.
dtype
)
c
=
self
.
channels
t
=
nn
.
Dense
(
c
*
4
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
a_s
,
a_b
,
o_s
,
o_b
=
jnp
.
split
(
t
[:,
None
,
:],
4
,
axis
=-
1
)
num_heads
=
max
(
2
,
c
//
128
)
f_actions
=
EncoderLayer
(
num_heads
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
src_key_padding_mask
=
mask
)
logits
=
mlp
((
c
//
4
,
1
),
use_bias
=
True
)(
f_actions
)
logits
=
logits
[
...
,
0
]
f_actions
=
get_encoder_layer_cls
(
self
.
noam
,
num_heads
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_actions
,
mask
,
a_s
,
a_b
,
o_s
,
o_b
)
logits
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
0.01
))(
f_actions
)[:,
:,
0
]
big_neg
=
jnp
.
finfo
(
logits
.
dtype
)
.
min
logits
=
jnp
.
where
(
mask
,
big_neg
,
logits
)
return
logits
class
Critic
(
nn
.
Module
):
channels
:
int
=
128
channels
:
Sequence
[
int
]
=
(
128
,
128
,
128
)
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
f_state
):
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
last_kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
x
=
MLP
((
c
//
2
,
1
),
use_bias
=
True
)(
f_state
)
f_state
=
f_state
.
astype
(
self
.
dtype
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
x
=
mlp
(
self
.
channels
,
last_lin
=
False
)(
f_state
)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
return
x
class
PPOAgent
(
nn
.
Module
):
channels
:
int
=
128
num_card_layers
:
int
=
2
num_action_layers
:
int
=
2
def
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
main
):
if
main
is
not
None
:
rstate1
,
rstate2
=
rstate
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
f_state
=
rnn_layer
(
rstate
,
f_state
)
if
main
is
not
None
:
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate
=
rstate1
,
rstate2
if
done
is
not
None
:
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
return
rstate
,
f_state
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
rstate
,
y
=
cell
(
rstate
,
x
)
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
rstate
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
switch
[:,
None
],
x
,
y
),
init_rstate2
,
rstate
)
return
(
rstate
,
init_rstate2
),
y
else
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
return
rnn_step_by_main
(
cell
,
carry
,
x
,
done
,
main
)
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
rstate
,
f_state
=
scan
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
return
rstate
,
f_state
# Hyper-parameters describing the agent network. The string placed after each
# field is kept verbatim as that field's description.
@dataclass
class ModelArgs:
    num_layers: int = 2
    """the number of layers for the agent"""

    num_channels: int = 128
    """the number of channels for the agent"""

    rnn_channels: int = 512
    """the number of channels for the RNN in the agent"""

    use_history: bool = True
    """whether to use history actions as input for agent"""

    card_mask: bool = False
    """whether to mask the padding card as ignored in the transformer"""

    rnn_type: Optional[Literal["lstm", "gru", "none"]] = "lstm"
    """the type of RNN to use, None for no RNN"""

    film: bool = False
    """whether to use FiLM for the actor"""

    noam: bool = False
    """whether to use Noam architecture for the transformer layer"""

    version: int = 0
    """the version of the environment and the agent"""
class
RNNAgent
(
nn
.
Module
):
num_layers
:
int
=
2
num_channels
:
int
=
128
rnn_channels
:
int
=
512
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
switch
:
bool
=
True
freeze_id
:
bool
=
False
use_history
:
bool
=
True
card_mask
:
bool
=
False
rnn_type
:
str
=
'lstm'
film
:
bool
=
False
noam
:
bool
=
False
version
:
int
=
0
@
nn
.
compact
def
__call__
(
self
,
x
):
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
c
=
self
.
num_channels
encoder
=
Encoder
(
channels
=
self
.
channels
,
num_card_layers
=
self
.
num_card_layers
,
num_action_layers
=
self
.
num_action_layers
,
channels
=
c
,
num_layers
=
self
.
num_layers
,
embedding_shape
=
self
.
embedding_shape
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
freeze_id
=
self
.
freeze_id
,
use_history
=
self
.
use_history
,
card_mask
=
self
.
card_mask
,
noam
=
self
.
noam
,
version
=
self
.
version
,
)
actor
=
Actor
(
channels
=
self
.
channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
channels
=
self
.
channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
logits
=
actor
(
f_actions
,
mask
)
value
=
critic
(
f_state
)
return
logits
,
value
,
valid
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
rnn_layer
=
nn
.
OptimizedLSTMCell
(
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
elif
self
.
rnn_type
==
'gru'
:
rnn_layer
=
nn
.
GRUCell
(
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
elif
self
.
rnn_type
is
None
:
rnn_layer
=
None
if
rnn_layer
is
None
:
f_state_r
=
f_state
elif
self
.
rnn_type
==
'none'
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
else
:
batch_size
=
jax
.
tree
.
leaves
(
rstate
)[
0
]
.
shape
[
0
]
num_steps
=
f_state
.
shape
[
0
]
//
batch_size
multi_step
=
num_steps
>
1
if
done
is
not
None
:
assert
switch_or_main
is
not
None
else
:
assert
not
multi_step
if
multi_step
:
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
rstate
,
f_state_r
=
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
)
if
self
.
film
:
actor
=
FiLMActor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
noam
=
self
.
noam
)
else
:
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
return
rstate
,
logits
,
value
,
valid
def
init_rnn_state
(
self
,
batch_size
):
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
return
(
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
)
elif
self
.
rnn_type
==
'gru'
:
return
np
.
zeros
((
batch_size
,
self
.
rnn_channels
))
else
:
return
None
\ No newline at end of file
ygoai/rl/jax/agent2.py
deleted
100644 → 0
View file @
f6139c17
from
dataclasses
import
dataclass
from
typing
import
Tuple
,
Union
,
Optional
,
Sequence
,
Literal
from
functools
import
partial
import
numpy
as
np
import
jax
import
jax.numpy
as
jnp
import
flax.linen
as
nn
from
ygoai.rl.jax.transformer
import
EncoderLayer
,
PositionalEncoding
,
LlamaEncoderLayer
from
ygoai.rl.jax.modules
import
MLP
,
make_bin_params
,
bytes_to_bin
,
decode_id
# Default initialisers shared by the encoders below. All three are separate
# uniform initialisers built with scale 0.001.
_small_uniform = partial(nn.initializers.uniform, scale=0.001)

default_embed_init = _small_uniform()
default_fc_init1 = _small_uniform()
default_fc_init2 = _small_uniform()
def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
    """Build one transformer encoder layer instance.

    Despite the ``_cls`` suffix this returns a constructed layer, not a class.

    Args:
        noam: when truthy, build a ``LlamaEncoderLayer`` (with ``rope=False``);
            otherwise build a plain ``EncoderLayer``.
        n_heads: passed positionally as the layer's first argument.
        dtype: forwarded as the ``dtype`` keyword.
        param_dtype: forwarded as the ``param_dtype`` keyword.
    """
    common = dict(dtype=dtype, param_dtype=param_dtype)
    if not noam:
        return EncoderLayer(n_heads, **common)
    return LlamaEncoderLayer(n_heads, rope=False, **common)
class ActionEncoder(nn.Module):
    """Embed each categorical action column separately.

    ``__call__`` indexes ``x[:, :, k]`` for ``k`` in 0..10 and returns a list
    of eleven embedded tensors (one per column, not concatenated).
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        wide = self.channels // 8
        narrow = wide // 2
        embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)

        # One (field name, vocabulary size, embedding width) entry per column,
        # in column order. The order also fixes the order in which the Embed
        # submodules are created, so it must not be changed.
        columns = (
            ('msg', 30, wide),
            ('act', 13, wide),
            ('yesno', 3, wide),
            ('phase', 4, wide),
            ('cancel', 3, wide),
            ('finish', 3, narrow),
            ('position', 9, narrow),
            ('option', 6, narrow),
            ('number', 13, narrow),
            ('place', 31, narrow),
            ('attrib', 10, narrow),
        )
        return [
            embed(vocab, width)(x[:, :, col])
            for col, (_field, vocab, width) in enumerate(columns)
        ]
class CardEncoder(nn.Module):
    """Encode per-card features and add location / sequence embeddings.

    ``__call__(x_id, x)`` returns ``(f_cards, c_mask)``. ``x`` is sliced as
    ``x[:, :, :10]`` (cast to int32, categorical columns) and ``x[:, :, 10:]``
    (cast to float32). ``c_mask`` is true where the location column is 0,
    except that position 0 along axis 1 is always forced to False.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x_id, x):
        c = self.channels
        narrow = c // 16
        typed = dict(dtype=self.dtype, param_dtype=self.param_dtype)

        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(nn.Embed, embedding_init=default_embed_init, **typed)
        fc_embed = partial(nn.Dense, use_bias=False, **typed)

        # Shared by the ATK and DEF columns (same module instance, same params).
        num_fc = MLP((c // 8,), last_lin=False, **typed)
        bin_points, bin_intervals = make_bin_params(n_bins=32)

        def num_transform(raw):
            return num_fc(bytes_to_bin(raw, bin_points, bin_intervals))

        categorical = x[:, :, :10].astype(jnp.int32)
        numeric = x[:, :, 10:].astype(jnp.float32)

        x_id = MLP((c, c // 4), kernel_init=default_fc_init2, **typed)(x_id)
        x_id = layer_norm()(x_id)

        x_loc = categorical[:, :, 0]
        c_mask = (x_loc == 0).at[:, 0].set(False)
        f_loc = layer_norm()(embed(9, c)(x_loc))
        f_seq = layer_norm()(embed(76, c)(categorical[:, :, 1]))

        # (field name, vocabulary size) for categorical columns 2..9, in order.
        small_fields = (
            ('owner', 2), ('position', 9), ('overley', 2), ('attribute', 8),
            ('race', 27), ('level', 14), ('counter', 16), ('negated', 3),
        )
        small_feats = [
            embed(vocab, narrow)(categorical[:, :, col])
            for col, (_field, vocab) in enumerate(small_fields, start=2)
        ]

        x_atk = num_transform(numeric[:, :, 0:2])
        x_atk = fc_embed(narrow, kernel_init=default_fc_init1)(x_atk)
        x_def = num_transform(numeric[:, :, 2:4])
        x_def = fc_embed(narrow, kernel_init=default_fc_init1)(x_def)
        x_type = fc_embed(narrow * 2, kernel_init=default_fc_init2)(numeric[:, :, 4:])

        x_f = jnp.concatenate([*small_feats, x_atk, x_def, x_type], axis=-1)
        x_f = layer_norm()(x_f)

        f_cards = jnp.concatenate([x_id, x_f], axis=-1)
        return f_cards + f_loc + f_seq, c_mask
class GlobalEncoder(nn.Module):
    """Encode the global (non-card) feature vector.

    ``__call__`` slices ``x`` along axis 1 into columns 0:4 (float32, fed
    through the binned numeric transform in two 2-column pairs), 4:8 (int32,
    four embedded categorical columns) and 8:22 (int32 counts), concatenates
    all embedded pieces on the last axis and layer-normalises the result.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        batch_size = x.shape[0]
        c = self.channels
        typed = dict(dtype=self.dtype, param_dtype=self.param_dtype)

        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(nn.Embed, embedding_init=default_embed_init, **typed)
        fc_embed = partial(nn.Dense, use_bias=False, **typed)

        # Created before any other Embed so submodule numbering is unchanged.
        count_embed = embed(100, c // 16)
        hand_count_embed = embed(100, c // 16)

        num_fc = MLP((c // 8,), last_lin=False, **typed)
        bin_points, bin_intervals = make_bin_params(n_bins=32)

        def num_transform(raw):
            return num_fc(bytes_to_bin(raw, bin_points, bin_intervals))

        lp_cols = x[:, :4].astype(jnp.float32)
        flag_cols = x[:, 4:8].astype(jnp.int32)
        count_cols = x[:, 8:22].astype(jnp.int32)

        x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(
            num_transform(lp_cols[:, 0:2]))
        x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(
            num_transform(lp_cols[:, 2:4]))

        # Vocabulary sizes for the columns named turn, phase, if_first and
        # is_my_turn in the original code, in that order.
        flag_feats = [
            embed(vocab, c // 8)(flag_cols[:, col])
            for col, vocab in enumerate((20, 11, 2, 2))
        ]

        x_cs = count_embed(count_cols).reshape((batch_size, -1))
        x_my_hand_c = hand_count_embed(count_cols[:, 1])
        x_op_hand_c = hand_count_embed(count_cols[:, 8])

        feats = jnp.concatenate(
            [x_lp, x_oppo_lp, *flag_feats, x_cs, x_my_hand_c, x_op_hand_c],
            axis=-1)
        return layer_norm()(feats)
class Encoder(nn.Module):
    """Encode an observation dict into action features and a state vector.

    ``__call__(x)`` reads ``x['cards_']``, ``x['global_']``, ``x['actions_']``
    and ``x['h_actions_']`` and returns ``(f_actions, f_state, a_mask, valid)``:

    * ``f_actions`` -- per-action features;
    * ``f_state``   -- state features built from the card, global, history
      and action summaries;
    * ``a_mask``    -- true where action column 2 is 0, with slot 0 forced False;
    * ``valid``     -- true where the last global column is 0.

    Submodules and parameters are created in the same order as before so that
    their automatically assigned names stay the same.
    """
    channels: int = 128
    num_layers: int = 2
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    freeze_id: bool = False
    use_history: bool = True
    card_mask: bool = False
    noam: bool = False

    @nn.compact
    def __call__(self, x):
        c = self.channels

        shape_spec = self.embedding_shape
        if shape_spec is None:
            n_embed, embed_dim = 999, 1024
        elif isinstance(shape_spec, int):
            n_embed, embed_dim = shape_spec, 1024
        else:
            n_embed, embed_dim = shape_spec
        n_embed = 1 + n_embed  # 1 (index 0) for unknown

        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(
            nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)
        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

        def small_normal(key, shape, dtype):
            # Initialiser for the two learned (1, c) placeholder embeddings.
            return jax.random.normal(key, shape, dtype) * 0.02

        def maybe_frozen(id_feats):
            # With freeze_id set, no gradient flows into the id embedding.
            if self.freeze_id:
                return jax.lax.stop_gradient(id_feats)
            return id_feats

        id_embed = embed(n_embed, embed_dim)
        action_encoder = ActionEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)

        cards_in = x['cards_']
        global_in = x['global_']
        actions_in = x['actions_']
        history_in = x['h_actions_']
        batch_size = cards_in.shape[0]

        valid = global_in[:, -1] == 0

        # Cards
        card_ids = decode_id(cards_in[:, :, :2].astype(jnp.int32))
        card_ids = maybe_frozen(id_embed(card_ids))
        f_cards, c_mask = CardEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(
            card_ids, cards_in[:, :, 2:])
        g_card_embed = self.param(
            'g_card_embed', small_normal, (1, c), self.param_dtype)
        f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
        # The learned global token is prepended along axis 1.
        f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
        if not self.card_mask:
            c_mask = None
        else:
            token_slot = jnp.zeros((batch_size, 1), dtype=c_mask.dtype)
            c_mask = jnp.concatenate([token_slot, c_mask], axis=1)

        num_heads = max(2, c // 128)
        for _ in range(self.num_layers):
            card_layer = get_encoder_layer_cls(
                self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
            f_cards = card_layer(f_cards, src_key_padding_mask=c_mask)
        f_cards = layer_norm(dtype=self.dtype)(f_cards)
        f_g_card = f_cards[:, 0]

        # Global
        global_feats = GlobalEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(global_in)
        global_feats = global_feats.astype(self.dtype)
        f_global = global_feats + MLP(
            (c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(global_feats)
        f_global = fc_layer(c, dtype=self.dtype)(f_global)
        f_global = layer_norm(dtype=self.dtype)(f_global)

        # History actions
        history_in = history_in.astype(jnp.int32)
        h_mask = history_in[:, :, 2] == 0  # msg == 0
        h_mask = h_mask.at[:, 0].set(False)

        hist_ids = maybe_frozen(id_embed(decode_id(history_in[..., :2])))
        hist_ids = MLP(
            (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
            kernel_init=default_fc_init2)(hist_ids)

        hist_action_parts = action_encoder(history_in[:, :, 2:13])
        hist_player = embed(2, c // 2)(history_in[:, :, 13])
        hist_turn = embed(20, c // 2)(history_in[:, :, 14])
        hist_feats = jnp.concatenate(
            [*hist_action_parts, hist_player, hist_turn], axis=-1)
        f_h_actions = layer_norm()(hist_ids) + layer_norm()(
            fc_layer(c, dtype=jnp.float32)(hist_feats))

        f_h_actions = PositionalEncoding()(f_h_actions)
        for _ in range(self.num_layers):
            f_h_actions = EncoderLayer(
                num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                f_h_actions, src_key_padding_mask=h_mask)
        f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])

        # Actions
        actions_in = actions_in.astype(jnp.int32)

        na_card_embed = self.param(
            'na_card_embed', small_normal, (1, c), self.param_dtype)
        f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
        # Slot 0 (the global token) is swapped for the 'na' embedding before
        # actions gather their card features by index.
        f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)

        spec_index = decode_id(actions_in[..., :2])
        batch_index = jnp.arange(batch_size)
        f_a_cards = f_cards[batch_index[:, None], spec_index]
        f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)

        x_a_feats = jnp.concatenate(action_encoder(actions_in[..., 2:]), axis=-1)
        x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
        f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
        f_actions = fc_layer(c, dtype=self.dtype)(
            nn.leaky_relu(f_actions, negative_slope=0.1))
        f_actions = layer_norm(dtype=self.dtype)(f_actions)

        a_mask = actions_in[:, :, 2] == 0
        a_mask = a_mask.at[:, 0].set(False)

        # Mean over unmasked actions; slot 0 is never masked, so the divisor
        # is at least 1.
        keep = 1 - a_mask.astype(f_actions.dtype)
        f_g_actions = (f_actions * keep[:, :, None]).sum(axis=1)
        f_g_actions = f_g_actions / keep.sum(axis=1, keepdims=True)

        # State
        if not self.use_history:
            f_g_h_actions = jnp.zeros_like(f_g_h_actions)
        f_state = jnp.concatenate(
            [f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
        f_state = MLP(
            (c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
        f_state = layer_norm(dtype=self.dtype)(f_state)
        return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
    """Score each candidate action by a dot product with the projected state.

    ``__call__(f_state, f_actions, mask)`` projects ``f_state`` with an MLP to
    ``channels`` features, contracts it with ``f_actions`` via
    ``'bc,bnc->bn'`` and overwrites masked entries with the most negative
    finite value of the logits dtype.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        state = f_state.astype(self.dtype)
        actions = f_actions.astype(self.dtype)

        project = MLP(
            (self.channels,),
            use_bias=True,
            dtype=jnp.float32,
            param_dtype=self.param_dtype,
            last_kernel_init=nn.initializers.orthogonal(0.01))
        query = project(state)

        scores = jnp.einsum('bc,bnc->bn', query, actions)
        # Masked actions get the lowest finite value of the logits dtype.
        floor = jnp.finfo(scores.dtype).min
        return jnp.where(mask, floor, scores)
class FiLMActor(nn.Module):
    """Actor that conditions an encoder layer over the actions on the state.

    A dense layer maps ``f_state`` to ``4 * channels`` features which are split
    into four equal parts (``a_s, a_b, o_s, o_b``) and passed, together with
    ``mask``, to one encoder layer applied to ``f_actions``. A final
    single-unit dense layer yields one logit per action; masked entries are
    replaced by the most negative finite value of the logits dtype.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    noam: bool = False

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        state = f_state.astype(self.dtype)
        actions = f_actions.astype(self.dtype)
        c = self.channels

        modulation = nn.Dense(
            c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(state)
        # A length-1 axis is inserted at position 1 before splitting, so each
        # part broadcasts against the per-action features.
        a_s, a_b, o_s, o_b = jnp.split(modulation[:, None, :], 4, axis=-1)

        num_heads = max(2, c // 128)
        layer = get_encoder_layer_cls(
            self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
        actions = layer(actions, mask, a_s, a_b, o_s, o_b)

        head = nn.Dense(
            1, dtype=jnp.float32, param_dtype=self.param_dtype,
            kernel_init=nn.initializers.orthogonal(0.01))
        scores = head(actions)[:, :, 0]
        floor = jnp.finfo(scores.dtype).min
        return jnp.where(mask, floor, scores)
class Critic(nn.Module):
    """Value head: an MLP over the state followed by a single-unit dense layer.

    ``channels`` lists the MLP layer widths. The final dense layer computes in
    float32 and keeps its trailing axis of size 1.
    """
    channels: Sequence[int] = (128, 128, 128)
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state):
        state = f_state.astype(self.dtype)
        trunk = MLP(
            self.channels, last_lin=False,
            dtype=self.dtype, param_dtype=self.param_dtype)
        value_head = nn.Dense(
            1, dtype=jnp.float32, param_dtype=self.param_dtype,
            kernel_init=nn.initializers.orthogonal(1.0))
        return value_head(trunk(state))
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
    """Advance ``rnn_layer`` by one step, optionally routing between two carries.

    Args:
        rnn_layer: callable ``(carry, inputs) -> (new_carry, outputs)``.
        rstate: the recurrent carry when ``main`` is None; otherwise a pair
            ``(rstate1, rstate2)`` of carries sharing one tree structure.
        f_state: inputs fed to ``rnn_layer`` for this step.
        done: optional boolean vector; rows where it is true get their
            returned carry replaced by zeros.
        main: optional boolean vector; true rows read from and write back to
            ``rstate1``, false rows use ``rstate2``.

    Returns:
        ``(rstate, outputs)`` where ``rstate`` has the same layout as the
        ``rstate`` argument (a pair when ``main`` is given).
    """
    use_two_states = main is not None

    def pick(when_main, otherwise):
        # Row-wise select between two carries. ``main[:, None]`` adds one
        # trailing axis, so leaves are assumed to be (batch, features) -- TODO confirm.
        return jax.tree.map(
            lambda a, b: jnp.where(main[:, None], a, b), when_main, otherwise)

    if use_two_states:
        state_a, state_b = rstate
        rstate = pick(state_a, state_b)

    rstate, f_state = rnn_layer(rstate, f_state)

    if use_two_states:
        # Write the fresh carry back only into the slot it was read from;
        # the other slot keeps its previous value.
        rstate = (pick(rstate, state_a), pick(state_b, rstate))

    if done is not None:
        rstate = jax.tree.map(lambda s: jnp.where(done[:, None], 0, s), rstate)
    return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
    """Scan ``rnn_layer`` over a sequence of steps for two-player rollouts.

    Args:
        rnn_layer: the recurrent cell, passed as the first argument of the
            scanned step function.
        rstate: initial scan carry. With ``switch=True`` it is unpacked as
            ``(rstate, init_rstate2)``; otherwise it is handed to
            ``rnn_step_by_main`` unchanged.
        f_state: per-step inputs, scanned together with ``done`` and
            ``switch_or_main``.
        done: per-step boolean flags; true rows have their carry zeroed.
        switch_or_main: per-step boolean flags, read as "switch" when
            ``switch`` is true and as "main" otherwise.
        switch: selects which step function is scanned.

    Returns:
        ``(final_carry, outputs)`` as produced by the scan.
    """
    if not switch:
        def body_fn(cell, carry, x, done, main):
            return rnn_step_by_main(cell, carry, x, done, main)
    else:
        def body_fn(cell, carry, x, done, switch):
            state, init_rstate2 = carry
            state, out = cell(state, x)
            # Zero the carry of finished rows first ...
            state = jax.tree.map(
                lambda s: jnp.where(done[:, None], 0, s), state)
            # ... then rows flagged by ``switch`` restart from ``init_rstate2``.
            state = jax.tree.map(
                lambda fresh, current: jnp.where(switch[:, None], fresh, current),
                init_rstate2, state)
            return (state, init_rstate2), out

    scanned = nn.scan(
        body_fn, variable_broadcast='params', split_rngs={'params': False})
    rstate, f_state = scanned(rnn_layer, rstate, f_state, done, switch_or_main)
    return rstate, f_state
# Hyper-parameters describing the agent network. The string placed after each
# field is kept verbatim as that field's description.
@dataclass
class ModelArgs:
    num_layers: int = 2
    """the number of layers for the agent"""

    num_channels: int = 128
    """the number of channels for the agent"""

    rnn_channels: int = 512
    """the number of channels for the RNN in the agent"""

    use_history: bool = True
    """whether to use history actions as input for agent"""

    card_mask: bool = False
    """whether to mask the padding card as ignored in the transformer"""

    rnn_type: Optional[Literal["lstm", "gru", "none"]] = "lstm"
    """the type of RNN to use, None for no RNN"""

    film: bool = False
    """whether to use FiLM for the actor"""

    noam: bool = False
    """whether to use Noam architecture for the transformer layer"""
class RNNAgent(nn.Module):
    """Encoder + optional recurrent core + actor / critic heads.

    ``rnn_type`` selects the recurrent core:

    * ``'lstm'`` / ``'gru'`` -- the encoded state goes through the cell;
    * ``'none'`` -- no recurrence; the encoded state is tiled
      ``rnn_channels // num_channels`` times along the last axis instead;
    * ``None`` -- the encoded state is passed to the heads unchanged.

    Any other value raises ``ValueError``.
    """
    num_layers: int = 2
    num_channels: int = 128
    rnn_channels: int = 512
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    switch: bool = True
    freeze_id: bool = False
    use_history: bool = True
    card_mask: bool = False
    rnn_type: str = 'lstm'
    film: bool = False
    noam: bool = False

    @nn.compact
    def __call__(self, x, rstate, done=None, switch_or_main=None):
        """Run the agent on a batch of observations.

        Args:
            x: observation dict consumed by ``Encoder``.
            rstate: recurrent state; its first leaf's leading dimension is
                taken as the batch size.
            done: optional per-row episode-end flags. Required when the input
                holds more than one step per batch row.
            switch_or_main: optional per-row flags forwarded to the recurrent
                helpers; must be given whenever ``done`` is given.

        Returns:
            ``(rstate, logits, value, valid)``.

        Raises:
            ValueError: if ``rnn_type`` is not 'lstm', 'gru', 'none' or None.
        """
        c = self.num_channels
        encoder = Encoder(
            channels=c,
            num_layers=self.num_layers,
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            freeze_id=self.freeze_id,
            use_history=self.use_history,
            card_mask=self.card_mask,
            noam=self.noam,
        )

        f_actions, f_state, mask, valid = encoder(x)

        if self.rnn_type in ['lstm', 'none']:
            rnn_layer = nn.OptimizedLSTMCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
                kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type == 'gru':
            rnn_layer = nn.GRUCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
                kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type is None:
            rnn_layer = None
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``rnn_layer``.
            raise ValueError(f"Unknown rnn_type: {self.rnn_type!r}")

        if rnn_layer is None:
            f_state_r = f_state
        elif self.rnn_type == 'none':
            # No recurrence: tile the state along the last axis.
            f_state_r = jnp.concatenate(
                [f_state for i in range(self.rnn_channels // c)], axis=-1)
        else:
            batch_size = jax.tree.leaves(rstate)[0].shape[0]
            # The encoder output is flat; its leading dimension is
            # num_steps * batch_size.
            num_steps = f_state.shape[0] // batch_size
            multi_step = num_steps > 1

            if done is not None:
                assert switch_or_main is not None
            else:
                assert not multi_step

            if multi_step:
                # Unflatten to (num_steps, batch_size, ...) so the scan runs
                # over the leading axis.
                f_state_r, done, switch_or_main = jax.tree.map(
                    lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]),
                    (f_state, done, switch_or_main))
                rstate, f_state_r = rnn_forward_2p(
                    rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
                f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
            else:
                rstate, f_state_r = rnn_step_by_main(
                    rnn_layer, rstate, f_state, done, switch_or_main)

        if self.film:
            actor = FiLMActor(
                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
                noam=self.noam)
        else:
            actor = Actor(
                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
        critic = Critic(
            channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)

        logits = actor(f_state_r, f_actions, mask)
        value = critic(f_state_r)
        return rstate, logits, value, valid

    def init_rnn_state(self, batch_size):
        """Return a zero recurrent state for ``batch_size`` rows.

        A pair of ``(batch_size, rnn_channels)`` arrays for 'lstm' / 'none',
        a single such array for 'gru', and None otherwise.
        """
        if self.rnn_type in ['lstm', 'none']:
            return (
                np.zeros((batch_size, self.rnn_channels)),
                np.zeros((batch_size, self.rnn_channels)),
            )
        elif self.rnn_type == 'gru':
            return np.zeros((batch_size, self.rnn_channels))
        else:
            return None
\ No newline at end of file
ygoai/utils.py
View file @
04e61b91
...
...
@@ -41,7 +41,12 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
raise
FileNotFoundError
(
f
"Token deck not found: {token_deck}"
)
decks
[
"_tokens"
]
=
str
(
token_deck
)
if
'YGOPro'
in
env_id
:
from
ygoenv.ygopro
import
init_module
if
env_id
==
'YGOPro-v1'
:
from
ygoenv.ygopro
import
init_module
elif
env_id
==
'YGOPro-v0'
:
from
ygoenv.ygopro0
import
init_module
else
:
raise
ValueError
(
f
"Unknown YGOPro environment: {env_id}"
)
elif
'EDOPro'
in
env_id
:
from
ygoenv.edopro
import
init_module
init_module
(
str
(
db_path
),
code_list_file
,
decks
)
...
...
ygoenv/ygoenv/
ygopro
/BS_thread_pool.h
→
ygoenv/ygoenv/
core
/BS_thread_pool.h
View file @
04e61b91
File moved
ygoenv/ygoenv/entry.py
View file @
04e61b91
...
...
@@ -18,13 +18,16 @@ try:
except
ImportError
:
pass
try
:
import
ygoenv.ygopro0.registration
# noqa: F401
except
ImportError
:
pass
try
:
import
ygoenv.edopro.registration
# noqa: F401
except
ImportError
:
pass
try
:
import
ygoenv.dummy.registration
# noqa: F401
except
ImportError
:
...
...
ygoenv/ygoenv/ygopro/registration.py
View file @
04e61b91
from
ygoenv.registration
import
register
register
(
task_id
=
"YGOPro-v
0
"
,
task_id
=
"YGOPro-v
1
"
,
import_path
=
"ygoenv.ygopro"
,
spec_cls
=
"YGOProEnvSpec"
,
dm_cls
=
"YGOProDMEnvPool"
,
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
04e61b91
...
...
@@ -23,7 +23,7 @@
#include <ankerl/unordered_dense.h>
#include <unordered_set>
#include "BS_thread_pool.h"
#include "
ygoenv/core/
BS_thread_pool.h"
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h"
...
...
@@ -305,13 +305,85 @@ static std::string msg_to_string(int msg) {
}
// system string
static
const
ankerl
::
unordered_dense
::
map
<
int
,
std
::
string
>
system_strings
=
{
static
const
std
::
map
<
int
,
std
::
string
>
system_strings
=
{
// announce type
{
1050
,
"Monster"
},
{
1051
,
"Spell"
},
{
1052
,
"Trap"
},
{
1054
,
"Normal"
},
{
1055
,
"Effect"
},
{
1056
,
"Fusion"
},
{
1057
,
"Ritual"
},
{
1058
,
"Trap Monsters"
},
{
1059
,
"Spirit"
},
{
1060
,
"Union"
},
{
1061
,
"Gemini"
},
{
1062
,
"Tuner"
},
{
1063
,
"Synchro"
},
{
1064
,
"Token"
},
{
1066
,
"Quick-Play"
},
{
1067
,
"Continuous"
},
{
1068
,
"Equip"
},
{
1069
,
"Field"
},
{
1070
,
"Counter"
},
{
1071
,
"Flip"
},
{
1072
,
"Toon"
},
{
1073
,
"Xyz"
},
{
1074
,
"Pendulum"
},
{
1075
,
"Special Summon"
},
{
1076
,
"Link"
},
{
1080
,
"(N/A)"
},
{
1081
,
"Extra Monster Zone"
},
// announce type end
// actions
{
1150
,
"Activate"
},
{
1151
,
"Normal Summon"
},
{
1152
,
"Special Summon"
},
{
1153
,
"Set"
},
{
1154
,
"Flip Summon"
},
{
1155
,
"To Defense"
},
{
1156
,
"To Attack"
},
{
1157
,
"Attack"
},
{
1158
,
"View"
},
{
1159
,
"S/T Set"
},
{
1160
,
"Put in Pendulum Zone"
},
{
1161
,
"Do Effect"
},
{
1162
,
"Reset Effect"
},
{
1163
,
"Pendulum Summon"
},
{
1164
,
"Synchro Summon"
},
{
1165
,
"Xyz Summon"
},
{
1166
,
"Link Summon"
},
{
1167
,
"Tribute Summon"
},
{
1168
,
"Ritual Summon"
},
{
1169
,
"Fusion Summon"
},
{
1190
,
"Add to hand"
},
{
1191
,
"Send to GY"
},
{
1192
,
"Banish"
},
{
1193
,
"Return to Deck"
},
// actions end
{
1
,
"Normal Summon"
},
{
30
,
"Replay rules apply. Continue this attack?"
},
{
31
,
"Attack directly with this monster?"
},
{
80
,
"Start Step of the Battle Phase."
},
{
81
,
"During the End Phase."
},
{
90
,
"Conduct this Normal Summon without Tributing?"
},
{
91
,
"Use additional Summon?"
},
{
92
,
"Tribute your opponent's monster?"
},
{
93
,
"Continue selecting Materials?"
},
{
94
,
"Activate this card's effect now?"
},
{
95
,
"Use the effect of [%ls]?"
},
{
96
,
"Use the effect of [%ls] to avoid destruction?"
},
{
97
,
"Place [%ls] to a Spell & Trap Zone?"
},
{
98
,
"Tribute a monster(s) your opponent controls?"
},
{
200
,
"From [%ls], activate [%ls]?"
},
{
203
,
"Chain another card or effect?"
},
{
210
,
"Continue selecting?"
},
{
218
,
"Pay LP by Effect of [%ls], instead?"
},
{
219
,
"Detach Xyz material by Effect of [%ls], instead?"
},
{
220
,
"Remove Counter(s) by Effect of [%ls], instead?"
},
{
221
,
"On [%ls], Activate Trigger Effect of [%ls]?"
},
{
222
,
"Activate Trigger Effect?"
},
{
221
,
"On [%ls], Activate Trigger Effect of [%ls]?"
},
{
1190
,
"Add to hand"
},
{
1192
,
"Banish"
},
{
1621
,
"Attack Negated"
},
{
1622
,
"[%ls] Missed timing"
}
};
...
...
@@ -321,7 +393,9 @@ static std::string get_system_string(int desc) {
if
(
it
!=
system_strings
.
end
())
{
return
it
->
second
;
}
return
"system string "
+
std
::
to_string
(
desc
);
throw
std
::
runtime_error
(
fmt
::
format
(
"Cannot find system string: {}"
,
desc
));
// return "system string " + std::to_string(desc);
}
static
std
::
string
ltrim
(
std
::
string
s
)
{
...
...
@@ -331,24 +405,6 @@ static std::string ltrim(std::string s) {
return
s
;
}
/// Expand a 32-bit zone flag into card-spec strings such as "m1" or "os8".
///
/// The flag is read as four bytes, lowest first, with prefixes "m", "s",
/// "om" and "os". Within a byte, bit i maps to suffix i + 1. By default a
/// CLEAR bit means the slot is listed; with `reverse` a SET bit is listed.
/// Specs are returned in ascending bit order.
inline std::vector<std::string> flag_to_usable_cardspecs(uint32_t flag,
                                                         bool reverse = false) {
  const char *const zone_prefix[4] = {"m", "s", "om", "os"};
  std::vector<std::string> specs;
  for (int bit = 0; bit < 32; ++bit) {
    const bool bit_set = ((flag >> bit) & 1u) != 0;
    // Listed when the bit is clear (default) or set (reverse).
    if (bit_set == reverse) {
      specs.push_back(std::string(zone_prefix[bit / 8]) +
                      std::to_string(bit % 8 + 1));
    }
  }
  return specs;
}
inline
std
::
string
ls_to_spec
(
uint8_t
loc
,
uint8_t
seq
,
uint8_t
pos
)
{
std
::
string
spec
;
...
...
@@ -402,7 +458,8 @@ spec_to_ls(const std::string spec) {
loc
=
LOCATION_DECK
;
offset
=
0
;
}
else
{
throw
std
::
runtime_error
(
"Invalid location"
);
std
::
string
s
=
fmt
::
format
(
"Invalid spec {}"
,
spec
);
throw
std
::
runtime_error
(
s
);
}
int
end
=
offset
;
while
(
end
<
spec
.
size
()
&&
std
::
isdigit
(
spec
[
end
]))
{
...
...
@@ -415,33 +472,19 @@ spec_to_ls(const std::string spec) {
return
{
loc
,
seq
,
pos
};
}
/// Pack a card location into one 32-bit code.
///
/// Layout, lowest byte first: opponent flag (0 or 1), `loc`, `seq`, `pos`.
inline uint32_t ls_to_spec_code(uint8_t loc, uint8_t seq, uint8_t pos,
                                bool opponent) {
  const uint32_t side = opponent ? 1u : 0u;
  return side | (static_cast<uint32_t>(loc) << 8) |
         (static_cast<uint32_t>(seq) << 16) |
         (static_cast<uint32_t>(pos) << 24);
}
inline
uint32_t
spec_to_code
(
const
std
::
string
&
spec
)
{
inline
std
::
tuple
<
uint8_t
,
uint8_t
,
uint8_t
,
uint8_t
>
spec_to_ls
(
uint8_t
player
,
const
std
::
string
spec
)
{
uint8_t
controller
=
player
;
int
offset
=
0
;
bool
opponent
=
false
;
if
(
spec
[
0
]
==
'o'
)
{
opponent
=
true
;
controller
=
1
-
player
;
offset
++
;
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
.
substr
(
offset
));
return
ls_to_spec_code
(
loc
,
seq
,
pos
,
opponent
)
;
return
{
controller
,
loc
,
seq
,
pos
}
;
}
/// Unpack a 32-bit spec code and format it through `ls_to_spec`.
///
/// Byte 0 is the opponent flag (opponent iff it equals 1); bytes 1, 2 and 3
/// hold `loc`, `seq` and `pos` respectively.
inline std::string code_to_spec(uint32_t spec_code) {
  const bool opponent = (spec_code & 0xff) == 1;
  const auto loc = static_cast<uint8_t>((spec_code >> 8) & 0xff);
  const auto seq = static_cast<uint8_t>((spec_code >> 16) & 0xff);
  const auto pos = static_cast<uint8_t>((spec_code >> 24) & 0xff);
  return ls_to_spec(loc, seq, pos, opponent);
}
static
std
::
tuple
<
std
::
vector
<
uint32
>
,
std
::
vector
<
uint32
>
,
std
::
vector
<
uint32
>>
read_decks
(
const
std
::
string
&
fp
)
{
std
::
ifstream
file
(
fp
);
...
...
@@ -567,6 +610,11 @@ inline std::string name(decltype(x_map)::key_type x) { \
return "unknown"; \
}
static
const
ankerl
::
unordered_dense
::
map
<
int
,
uint8_t
>
system_string2id
=
make_ids
(
system_strings
,
16
);
DEFINE_X_TO_ID_FUN
(
system_string_to_id
,
system_string2id
)
static
const
std
::
map
<
uint8_t
,
std
::
string
>
location2str
=
{
{
LOCATION_DECK
,
"Deck"
},
{
LOCATION_HAND
,
"Hand"
},
...
...
@@ -722,29 +770,152 @@ static const ankerl::unordered_dense::map<int, uint8_t> msg2id =
DEFINE_X_TO_ID_FUN
(
msg_to_id
,
msg2id
)
static
const
ankerl
::
unordered_dense
::
map
<
char
,
uint8_t
>
cmd_act2id
=
make_ids
({
't'
,
'r'
,
'c'
,
's'
,
'm'
,
'a'
,
'v'
},
1
);
DEFINE_X_TO_ID_FUN
(
cmd_act_to_id
,
cmd_act2id
)
/// Kind of act carried by a legal action. `None` means the action has no act.
/// Enumerators are numbered implicitly from 0 (None) to 9 (Cancel); the
/// numeric value is what gets stored when the enum is cast to an integer, so
/// the order must not be changed.
enum class ActionAct {
  None,
  Set,
  Repo,
  SpSummon,
  Summon,
  MSet,
  Attack,
  DirectAttack,
  Activate,
  Cancel,
};

/// Returns the enumerator name of `act` as a string.
///
/// @param act  the act to describe
/// @return the enumerator's spelling (e.g. "SpSummon"), or "Unknown" for a
///         value that is not one of the declared enumerators.
inline std::string action_act_to_string(ActionAct act) {
  switch (act) {
    case ActionAct::None:
      return "None";
    case ActionAct::Set:
      return "Set";
    case ActionAct::Repo:
      return "Repo";
    case ActionAct::SpSummon:
      return "SpSummon";
    case ActionAct::Summon:
      return "Summon";
    case ActionAct::MSet:
      return "MSet";
    case ActionAct::Attack:
      return "Attack";
    case ActionAct::DirectAttack:
      return "DirectAttack";
    case ActionAct::Activate:
      return "Activate";
    case ActionAct::Cancel:
      return "Cancel";
    default:
      // Reached only for values cast from an out-of-range integer.
      return "Unknown";
  }
}
static
const
ankerl
::
unordered_dense
::
map
<
char
,
uint8_t
>
cmd_phase2id
=
make_ids
(
std
::
vector
<
char
>
({
'b'
,
'm'
,
'e'
}),
1
);
DEFINE_X_TO_ID_FUN
(
cmd_phase_to_id
,
cmd_phase2id
)
/// Phase a legal action asks to move to. `None` means the action is not a
/// phase change. Enumerators are numbered implicitly from 0 (None) to 3 (End);
/// the numeric value is what gets stored when the enum is cast to an integer.
enum class ActionPhase {
  None,
  Battle,
  Main2,
  End,
};

/// Returns the enumerator name of `phase` as a string.
///
/// @param phase  the phase to describe
/// @return "None", "Battle", "Main2" or "End"; "Unknown" for a value that is
///         not one of the declared enumerators.
inline std::string action_phase_to_string(ActionPhase phase) {
  switch (phase) {
    case ActionPhase::None:
      return "None";
    case ActionPhase::Battle:
      return "Battle";
    case ActionPhase::Main2:
      return "Main2";
    case ActionPhase::End:
      return "End";
    default:
      // Reached only for values cast from an out-of-range integer.
      return "Unknown";
  }
}
static
const
ankerl
::
unordered_dense
::
map
<
char
,
uint8_t
>
cmd_yesno2id
=
make_ids
(
std
::
vector
<
char
>
({
'y'
,
'n'
}),
1
);
DEFINE_X_TO_ID_FUN
(
cmd_yesno_to_id
,
cmd_yesno2id
)
/// Field zone that a legal action can designate. `None` (0) means no zone.
/// The four groups are laid out back to back, so a zone is addressed as
/// "first enumerator of its group + zero-based slot":
///   own monster zones       1..7
///   own spell/trap zones    8..15
///   opponent monster zones 16..22
///   opponent spell/trap    23..30
enum class ActionPlace {
  None = 0,

  // Own monster zones (7 slots).
  MZone1 = 1, MZone2, MZone3, MZone4, MZone5, MZone6, MZone7,

  // Own spell/trap zones (8 slots).
  SZone1 = 8, SZone2, SZone3, SZone4, SZone5, SZone6, SZone7, SZone8,

  // Opponent monster zones (7 slots).
  OpMZone1 = 16, OpMZone2, OpMZone3, OpMZone4, OpMZone5, OpMZone6, OpMZone7,

  // Opponent spell/trap zones (8 slots).
  OpSZone1 = 23, OpSZone2, OpSZone3, OpSZone4, OpSZone5, OpSZone6, OpSZone7,
  OpSZone8,
};
static
const
ankerl
::
unordered_dense
::
map
<
std
::
string
,
uint8_t
>
cmd_place2id
=
make_ids
(
std
::
vector
<
std
::
string
>
(
{
"m1"
,
"m2"
,
"m3"
,
"m4"
,
"m5"
,
"m6"
,
"m7"
,
"s1"
,
"s2"
,
"s3"
,
"s4"
,
"s5"
,
"s6"
,
"s7"
,
"s8"
,
"om1"
,
"om2"
,
"om3"
,
"om4"
,
"om5"
,
"om6"
,
"om7"
,
"os1"
,
"os2"
,
"os3"
,
"os4"
,
"os5"
,
"os6"
,
"os7"
,
"os8"
}),
1
);
DEFINE_X_TO_ID_FUN
(
cmd_place_to_id
,
cmd_place2id
)
inline
std
::
vector
<
ActionPlace
>
flag_to_usable_places
(
uint32_t
flag
,
bool
reverse
=
false
)
{
std
::
vector
<
ActionPlace
>
places
;
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
uint32_t
value
=
(
flag
>>
(
j
*
8
))
&
0xff
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
bool
avail
=
(
value
&
(
1
<<
i
))
==
0
;
if
(
reverse
)
{
avail
=
!
avail
;
}
if
(
avail
)
{
ActionPlace
place
;
if
(
j
==
0
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
MZone1
));
}
else
if
(
j
==
1
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
SZone1
));
}
else
if
(
j
==
2
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
));
}
else
if
(
j
==
3
)
{
place
=
static_cast
<
ActionPlace
>
(
i
+
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
));
}
places
.
push_back
(
place
);
}
}
}
return
places
;
}
/// Formats a place as a short tag: "None" for ActionPlace::None, otherwise a
/// group prefix followed by the 1-based slot within the group
/// ("m1".."m7", "s1".."s8", "om1".."om7", "os1".."os8").
///
/// @param place  the place to describe
/// @return the tag, or "Unknown" for a value outside every declared group.
inline std::string action_place_to_string(ActionPlace place) {
  const auto ord = [](ActionPlace p) { return static_cast<int>(p); };
  const int v = ord(place);

  if (v == 0) {
    return "None";
  }
  if (v >= ord(ActionPlace::MZone1) && v <= ord(ActionPlace::MZone7)) {
    return fmt::format("m{}", v - ord(ActionPlace::MZone1) + 1);
  }
  if (v >= ord(ActionPlace::SZone1) && v <= ord(ActionPlace::SZone8)) {
    return fmt::format("s{}", v - ord(ActionPlace::SZone1) + 1);
  }
  if (v >= ord(ActionPlace::OpMZone1) && v <= ord(ActionPlace::OpMZone7)) {
    return fmt::format("om{}", v - ord(ActionPlace::OpMZone1) + 1);
  }
  if (v >= ord(ActionPlace::OpSZone1) && v <= ord(ActionPlace::OpSZone8)) {
    return fmt::format("os{}", v - ord(ActionPlace::OpSZone1) + 1);
  }
  return "Unknown";
}
inline
std
::
pair
<
uint8_t
,
uint8_t
>
float_transform
(
int
x
)
{
...
...
@@ -807,6 +978,89 @@ using PlayerId = uint8_t;
using
CardCode
=
uint32_t
;
using
CardId
=
uint16_t
;
// NOTE(review): not referenced in the visible part of this file; presumably
// an upper bound on description codes - confirm against the callers.
const
int
DESCRIPTION_LIMIT
=
10000
;
// Effect indices at or above this value denote a card's own effect strings:
// the offset is subtracted to obtain the index into the card's strings, and
// the action feature encoding maps them to ids starting at 2.
const
int
CARD_EFFECT_OFFSET
=
10010
;
/// One choice offered to a player for the current message.
///
/// A plain value type: every field has a "nothing selected" default, and the
/// static factories below build an action with only the relevant fields set.
class LegalAction {
public:
  // Spec string of the card the action refers to; empty when no card is involved.
  std::string spec_ = "";
  // Act performed on the card (ActionAct::None when not applicable).
  ActionAct act_ = ActionAct::None;
  // Phase to move to (ActionPhase::None when not a phase change).
  ActionPhase phase_ = ActionPhase::None;
  // True for the "finish selecting" choice of a multi-select.
  bool finish_ = false;
  // Chosen card position; 0 when not a position choice.
  uint8_t position_ = 0;
  // Effect index: -1 = none, 0 = default, >= CARD_EFFECT_OFFSET = card
  // effect, any other value = system string id.
  int effect_ = -1;
  // Announced number; 0 when not a number choice.
  uint8_t number_ = 0;
  // Chosen zone (ActionPlace::None when not a place choice).
  ActionPlace place_ = ActionPlace::None;
  // Announced attribute; 0 when not an attribute choice.
  uint8_t attribute_ = 0;
  // Index of the card named by spec_, filled in from the spec lookup table.
  int spec_index_ = 0;
  // Card id of the card named by spec_; 0 when unknown or not set.
  CardId cid_ = 0;
  // Message this action answers.
  int msg_ = 0;

  /// Action that only names a card.
  static LegalAction from_spec(const std::string &spec) {
    LegalAction result;
    result.spec_ = spec;
    return result;
  }

  /// Action that performs `act` on the card named by `spec`.
  static LegalAction act_spec(ActionAct act, const std::string &spec) {
    LegalAction result = from_spec(spec);
    result.act_ = act;
    return result;
  }

  /// "Finish" choice of a multi-select.
  static LegalAction finish() {
    LegalAction result;
    result.finish_ = true;
    return result;
  }

  /// "Cancel" choice.
  static LegalAction cancel() {
    LegalAction result;
    result.act_ = ActionAct::Cancel;
    return result;
  }

  /// Activation of effect `effect_idx` of the card named by `spec`.
  static LegalAction activate_spec(int effect_idx, const std::string &spec) {
    LegalAction result = act_spec(ActionAct::Activate, spec);
    result.effect_ = effect_idx;
    return result;
  }

  /// Phase change to `phase`.
  static LegalAction phase(ActionPhase phase) {
    LegalAction result;
    result.phase_ = phase;
    return result;
  }

  /// Announcement of `number`.
  static LegalAction number(uint8_t number) {
    LegalAction result;
    result.number_ = number;
    return result;
  }

  /// Selection of the zone `place`.
  static LegalAction place(ActionPlace place) {
    LegalAction result;
    result.place_ = place;
    return result;
  }

  /// Announcement of `attribute` (stored truncated to 8 bits).
  static LegalAction attribute(int attribute) {
    LegalAction result;
    result.attribute_ = static_cast<uint8_t>(attribute);
    return result;
  }
};
/// Lookup record stored per card spec: where the card sits in the card
/// observation and which card id it has.
class SpecInfo {
public:
  uint16_t index;  // index assigned to the card when the observation is built
  CardId cid;      // card id; 0 when the card is hidden from the observer
};
class
Card
{
friend
class
YGOProEnv
;
...
...
@@ -874,42 +1128,23 @@ public:
return
get_spec
(
player
!=
controler_
);
}
uint32_t
get_spec_code
(
PlayerId
player
)
const
{
return
ls_to_spec_code
(
location_
,
sequence_
,
position_
,
player
!=
controler_
);
}
std
::
string
get_position
()
const
{
return
position_to_string
(
position_
);
}
std
::
string
get_effect_description
(
uint32_t
desc
,
bool
existing
=
false
)
const
{
std
::
string
s
;
bool
e
=
false
;
auto
code
=
code_
;
if
(
desc
>
10000
)
{
code
=
desc
>>
4
;
std
::
string
get_effect_description
(
CardCode
code
,
int
effect_idx
)
const
{
if
(
code
==
0
)
{
return
get_system_string
(
effect_idx
);
}
uint32_t
offset
=
desc
-
code_
*
16
;
bool
in_range
=
(
offset
>=
0
)
&&
(
offset
<
strings_
.
size
());
std
::
string
str
=
""
;
if
(
in_range
)
{
str
=
ltrim
(
strings_
[
offset
]);
if
(
effect_idx
==
0
)
{
return
"default"
;
}
if
(
in_range
||
desc
==
0
)
{
if
((
desc
==
0
)
||
str
.
empty
())
{
s
=
"Activate "
+
name_
+
"."
;
}
else
{
s
=
name_
+
" ("
+
str
+
")"
;
e
=
true
;
}
}
else
{
s
=
get_system_string
(
desc
);
if
(
!
s
.
empty
())
{
e
=
true
;
}
effect_idx
-=
CARD_EFFECT_OFFSET
;
if
(
effect_idx
<
0
)
{
throw
std
::
runtime_error
(
fmt
::
format
(
"Invalid effect index: {}"
,
effect_idx
));
}
if
(
existing
&&
!
e
)
{
s
=
""
;
auto
s
=
strings_
[
effect_idx
];
if
(
s
.
empty
())
{
return
"effect "
+
std
::
to_string
(
effect_idx
);
}
return
s
;
}
...
...
@@ -1222,7 +1457,7 @@ public:
const
int
&
init_lp
()
const
{
return
init_lp_
;
}
virtual
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
=
0
;
virtual
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
=
0
;
};
class
GreedyAI
:
public
Player
{
...
...
@@ -1232,7 +1467,7 @@ public:
bool
verbose
=
false
)
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
)
{}
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
override
{
return
0
;
}
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
override
{
return
0
;
}
};
class
RandomAI
:
public
Player
{
...
...
@@ -1246,8 +1481,8 @@ public:
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
),
gen_
(
seed
),
dist_
(
0
,
max_options
-
1
)
{}
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
override
{
return
dist_
(
gen_
)
%
op
tions
.
size
();
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
override
{
return
dist_
(
gen_
)
%
ac
tions
.
size
();
}
};
...
...
@@ -1258,17 +1493,17 @@ public:
bool
verbose
=
false
)
:
Player
(
nickname
,
init_lp
,
duel_player
,
verbose
)
{}
int
think
(
const
std
::
vector
<
std
::
string
>
&
op
tions
)
override
{
int
think
(
const
std
::
vector
<
LegalAction
>
&
ac
tions
)
override
{
while
(
true
)
{
std
::
string
input
=
getline
();
if
(
input
==
"quit"
)
{
exit
(
0
);
}
auto
it
=
std
::
find
(
options
.
begin
(),
options
.
end
(),
input
)
;
if
(
i
t
!=
options
.
end
())
{
return
std
::
distance
(
options
.
begin
(),
it
)
;
int
idx
=
std
::
stoi
(
input
)
-
1
;
if
(
i
dx
>=
0
&&
idx
<
actions
.
size
())
{
return
idx
;
}
else
{
fmt
::
println
(
"{} Choose from {}
"
,
duel_player_
,
options
);
fmt
::
println
(
"{} Choose from {}
actions"
,
duel_player_
,
actions
.
size
()
);
}
}
}
...
...
@@ -1286,7 +1521,7 @@ public:
}
template
<
typename
Config
>
static
decltype
(
auto
)
StateSpec
(
const
Config
&
conf
)
{
int
n_action_feats
=
1
3
;
int
n_action_feats
=
1
2
;
return
MakeDict
(
"obs:cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
41
})),
"obs:global_"
_
.
Bind
(
Spec
<
uint8_t
>
({
23
})),
...
...
@@ -1393,7 +1628,7 @@ protected:
int
turn_count_
;
int
msg_
;
std
::
vector
<
std
::
string
>
op
tions_
;
std
::
vector
<
LegalAction
>
legal_ac
tions_
;
PlayerId
to_play_
;
std
::
function
<
void
(
int
)
>
callback_
;
...
...
@@ -1423,9 +1658,10 @@ protected:
const
int
n_history_actions_
;
// circular buffer for history actions
TArray
<
uint8_t
>
history_actions_
;
int
ha_p_
=
0
;
std
::
vector
<
CardId
>
h_card_ids_
;
TArray
<
uint8_t
>
history_actions_1_
;
TArray
<
uint8_t
>
history_actions_2_
;
int
ha_p_1_
=
0
;
int
ha_p_2_
=
0
;
std
::
unordered_set
<
std
::
string
>
revealed_
;
...
...
@@ -1487,8 +1723,9 @@ public:
int
max_options
=
spec
.
config
[
"max_options"
_
];
int
n_action_feats
=
spec
.
state_spec
[
"obs:actions_"
_
].
shape
[
1
];
h_card_ids_
.
resize
(
max_options
);
history_actions_
=
TArray
<
uint8_t
>
(
Array
(
history_actions_1_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
+
2
})));
history_actions_2_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
+
2
})));
}
...
...
@@ -1560,8 +1797,10 @@ public:
turn_count_
=
0
;
ms_idx_
=
-
1
;
history_actions_
.
Zero
();
ha_p_
=
0
;
history_actions_1_
.
Zero
();
history_actions_2_
.
Zero
();
ha_p_1_
=
0
;
ha_p_2_
=
0
;
clock_t
_start
=
clock
();
...
...
@@ -1720,7 +1959,7 @@ public:
if
(
ms_mode_
==
0
)
{
for
(
int
j
=
0
;
j
<
ms_specs_
.
size
();
++
j
)
{
const
auto
&
spec
=
ms_specs_
[
j
];
options_
.
push_back
(
spec
);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
spec
)
);
}
}
else
{
ms_combs_
=
combs
;
...
...
@@ -1729,22 +1968,23 @@ public:
}
void
handle_multi_select
()
{
options_
=
{}
;
legal_actions_
.
clear
()
;
if
(
ms_mode_
==
0
)
{
for
(
int
j
=
0
;
j
<
ms_specs_
.
size
();
++
j
)
{
if
(
ms_spec2idx_
.
find
(
ms_specs_
[
j
])
!=
ms_spec2idx_
.
end
())
{
options_
.
push_back
(
ms_specs_
[
j
]);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
ms_specs_
[
j
]));
}
}
if
(
ms_idx_
==
ms_max_
-
1
)
{
if
(
ms_idx_
>=
ms_min_
)
{
options_
.
push_back
(
"f"
);
legal_actions_
.
push_back
(
LegalAction
::
finish
()
);
}
callback_
=
[
this
](
int
idx
)
{
_callback_multi_select
(
idx
,
true
);
};
}
else
if
(
ms_idx_
>=
ms_min_
)
{
options_
.
push_back
(
"f"
);
legal_actions_
.
push_back
(
LegalAction
::
finish
()
);
callback_
=
[
this
](
int
idx
)
{
_callback_multi_select
(
idx
,
false
);
};
...
...
@@ -1766,7 +2006,7 @@ public:
if
(
it
!=
ms_spec2idx_
.
end
())
{
return
it
->
second
;
}
// TODO: find the root cause
// TODO
(2)
: find the root cause
// print ms_spec2idx
show_deck
(
0
);
show_deck
(
1
);
...
...
@@ -1783,11 +2023,15 @@ public:
}
void
_callback_multi_select_2
(
int
idx
)
{
const
auto
&
option
=
op
tions_
[
idx
];
idx
=
get_ms_spec_idx
(
option
);
const
auto
&
action
=
legal_ac
tions_
[
idx
];
idx
=
get_ms_spec_idx
(
action
.
spec_
);
if
(
idx
==
-
1
)
{
// TODO: find the root cause
fmt
::
println
(
"options: {}, idx: {}, option: {}"
,
options_
,
idx
,
option
);
// TODO(2): find the root cause
std
::
vector
<
std
::
string
>
specs
;
for
(
const
auto
&
la
:
legal_actions_
)
{
specs
.
push_back
(
la
.
spec_
);
}
fmt
::
println
(
"specs: {}, idx: {}, spec: {}"
,
specs
,
idx
,
action
.
spec_
);
throw
std
::
runtime_error
(
"Spec not found"
);
}
ms_r_idxs_
.
push_back
(
idx
);
...
...
@@ -1814,7 +2058,7 @@ public:
}
for
(
auto
&
i
:
comb
)
{
const
auto
&
spec
=
ms_specs_
[
i
];
options_
.
push_back
(
spec
);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
spec
)
);
}
}
...
...
@@ -1831,17 +2075,21 @@ public:
}
void
_callback_multi_select
(
int
idx
,
bool
finish
)
{
const
auto
&
option
=
op
tions_
[
idx
];
const
auto
&
action
=
legal_ac
tions_
[
idx
];
// fmt::println("Select card: {}, finish: {}", option, finish);
if
(
option
==
"f"
)
{
if
(
action
.
finish_
)
{
finish
=
true
;
}
else
{
idx
=
get_ms_spec_idx
(
option
);
idx
=
get_ms_spec_idx
(
action
.
spec_
);
if
(
idx
!=
-
1
)
{
ms_r_idxs_
.
push_back
(
idx
);
}
else
{
// TODO: find the root cause
fmt
::
println
(
"options: {}, idx: {}, option: {}"
,
options_
,
idx
,
option
);
// TODO(2): find the root cause
std
::
vector
<
std
::
string
>
specs
;
for
(
const
auto
&
la
:
legal_actions_
)
{
specs
.
push_back
(
la
.
spec_
);
}
fmt
::
println
(
"specs: {}, idx: {}, spec: {}"
,
specs
,
idx
,
action
.
spec_
);
ms_idx_
=
-
1
;
resp_buf_
[
0
]
=
ms_min_
;
for
(
int
i
=
0
;
i
<
ms_min_
;
++
i
)
{
...
...
@@ -1860,27 +2108,27 @@ public:
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
}
else
{
ms_idx_
++
;
ms_spec2idx_
.
erase
(
option
);
ms_spec2idx_
.
erase
(
action
.
spec_
);
}
}
void
update_h_card_ids
(
PlayerId
player
,
int
idx
)
{
h_card_ids_
[
idx
]
=
parse_card_id
(
options_
[
idx
],
player
);
}
void
update_history_actions
(
PlayerId
player
,
int
idx
)
{
if
((
msg_
==
MSG_SELECT_CHAIN
)
&
(
options_
[
idx
][
0
]
==
'c'
))
{
void
update_history_actions
(
PlayerId
player
,
const
LegalAction
&
action
)
{
if
(
action
.
act_
==
ActionAct
::
Cancel
)
{
return
;
}
ha_p_
--
;
if
(
ha_p_
<
0
)
{
ha_p_
=
n_history_actions_
-
1
;
auto
&
ha_p
=
player
==
0
?
ha_p_1_
:
ha_p_2_
;
auto
&
history_actions
=
player
==
0
?
history_actions_1_
:
history_actions_2_
;
ha_p
--
;
if
(
ha_p
<
0
)
{
ha_p
=
n_history_actions_
-
1
;
}
history_actions_
[
ha_p_
].
Zero
();
_set_obs_action
(
history_actions_
,
ha_p_
,
msg_
,
options_
[
idx
],
{},
h_card_ids_
[
idx
]);
history_actions_
[
ha_p_
](
13
)
=
static_cast
<
uint8_t
>
(
player
);
history_actions_
[
ha_p_
](
14
)
=
static_cast
<
uint8_t
>
(
turn_count_
);
history_actions
[
ha_p
].
Zero
();
_set_obs_action
(
history_actions
,
ha_p
,
action
);
// Spec index not available in history actions
history_actions
[
ha_p
](
0
)
=
0
;
// history_actions[ha_p](12) = static_cast<uint8_t>(player);
history_actions
[
ha_p
](
12
)
=
static_cast
<
uint8_t
>
(
turn_count_
);
history_actions
[
ha_p
](
13
)
=
static_cast
<
uint8_t
>
(
phase_to_id
(
current_phase_
));
}
void
show_deck
(
const
std
::
vector
<
CardCode
>
&
deck
,
const
std
::
string
&
prefix
)
const
{
...
...
@@ -1910,18 +2158,18 @@ public:
}
void
show_history_actions
(
PlayerId
player
)
const
{
const
auto
&
ha
=
history_actions
_
;
const
auto
&
ha
=
player
==
0
?
history_actions_1_
:
history_actions_2
_
;
// print card ids of history actions
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
fmt
::
print
(
"history {}
\n
"
,
i
);
uint8_t
msg_id
=
uint8_t
(
ha
(
i
,
2
));
uint8_t
msg_id
=
uint8_t
(
ha
(
i
,
3
));
int
msg
=
_msgs
[
msg_id
-
1
];
fmt
::
print
(
"msg: {},"
,
msg_to_string
(
msg
));
uint8_t
v1
=
ha
(
i
,
0
);
uint8_t
v2
=
ha
(
i
,
1
);
uint8_t
v1
=
ha
(
i
,
1
);
uint8_t
v2
=
ha
(
i
,
2
);
CardId
card_id
=
(
static_cast
<
CardId
>
(
v1
)
<<
8
)
+
static_cast
<
CardId
>
(
v2
);
fmt
::
print
(
" {};"
,
card_id
);
for
(
int
j
=
3
;
j
<
ha
.
Shape
()[
1
];
j
++
)
{
for
(
int
j
=
4
;
j
<
ha
.
Shape
()[
1
];
j
++
)
{
fmt
::
print
(
" {}"
,
uint8_t
(
ha
(
i
,
j
)));
}
fmt
::
print
(
"
\n
"
);
...
...
@@ -1933,7 +2181,7 @@ public:
int
idx
=
action
[
"action"
_
];
callback_
(
idx
);
update_history_actions
(
to_play_
,
idx
);
update_history_actions
(
to_play_
,
legal_actions_
[
idx
]
);
PlayerId
player
=
to_play_
;
...
...
@@ -2012,10 +2260,10 @@ public:
}
private:
using
SpecIn
dex
=
ankerl
::
unordered_dense
::
map
<
std
::
string
,
uint16_t
>
;
using
SpecIn
fos
=
ankerl
::
unordered_dense
::
map
<
std
::
string
,
SpecInfo
>
;
std
::
tuple
<
SpecIn
dex
,
std
::
vector
<
int
>>
_set_obs_cards
(
TArray
<
uint8_t
>
&
f_cards
,
PlayerId
to_play
)
{
SpecIn
dex
spec2index
;
std
::
tuple
<
SpecIn
fos
,
std
::
vector
<
int
>>
_set_obs_cards
(
TArray
<
uint8_t
>
&
f_cards
,
PlayerId
to_play
)
{
SpecIn
fos
spec_infos
;
std
::
vector
<
int
>
loc_n_cards
;
int
offset
=
0
;
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
...
...
@@ -2054,18 +2302,23 @@ private:
hide
=
false
;
}
}
CardId
card_id
=
0
;
if
(
!
hide
)
{
card_id
=
c_get_card_id
(
c
.
code_
);
}
_set_obs_card_
(
f_cards
,
offset
,
c
,
hide
);
offset
++
;
spec2index
[
spec
]
=
static_cast
<
uint16_t
>
(
offset
);
spec_infos
[
spec
]
=
{
static_cast
<
uint16_t
>
(
offset
),
card_id
};
}
}
}
}
return
{
spec
2index
,
loc_n_cards
};
return
{
spec
_infos
,
loc_n_cards
};
}
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
bool
hide
)
{
bool
hide
,
CardId
card_id
=
0
)
{
// check offset exceeds max_cards
uint8_t
location
=
c
.
location_
;
bool
overlay
=
location
&
LOCATION_OVERLAY
;
...
...
@@ -2077,7 +2330,6 @@ private:
}
if
(
!
hide
)
{
auto
card_id
=
c_get_card_id
(
c
.
code_
);
f_cards
(
offset
,
0
)
=
static_cast
<
uint8_t
>
(
card_id
>>
8
);
f_cards
(
offset
,
1
)
=
static_cast
<
uint8_t
>
(
card_id
&
0xff
);
}
...
...
@@ -2148,17 +2400,10 @@ private:
}
}
void
_set_obs_action_spec
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
const
std
::
string
&
spec
,
const
SpecIndex
&
spec2index
,
CardId
card_id
=
0
)
{
uint16_t
idx
;
if
(
spec2index
.
empty
())
{
idx
=
card_id
;
}
else
{
auto
it
=
spec2index
.
find
(
spec
);
if
(
it
==
spec2index
.
end
())
{
// TODO: find the root cause
const
SpecInfo
&
find_spec_info
(
SpecInfos
&
spec_infos
,
const
std
::
string
&
spec
)
{
auto
it
=
spec_infos
.
find
(
spec
);
if
(
it
==
spec_infos
.
end
())
{
// TODO(2): find the root cause
// print spec2index
show_deck
(
0
);
show_deck
(
1
);
...
...
@@ -2166,135 +2411,111 @@ private:
show_turn
();
fmt
::
println
(
"MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}"
,
ms_idx_
,
ms_mode_
,
ms_min_
,
ms_max_
,
ms_must_
,
ms_specs_
,
ms_combs_
);
fmt
::
println
(
"Spec: {}, Spec2index:"
,
spec
);
for
(
auto
&
[
k
,
v
]
:
spec
2index
)
{
fmt
::
print
(
"{}: {}
, "
,
k
,
v
);
for
(
auto
&
[
k
,
v
]
:
spec
_infos
)
{
fmt
::
print
(
"{}: {}
{}, "
,
k
,
v
.
index
,
v
.
cid
);
}
fmt
::
print
(
"
\n
"
);
// throw std::runtime_error("Spec not found: " + spec);
idx
=
1
;
}
else
{
idx
=
it
->
second
;
}
spec_infos
[
spec
]
=
{
1
,
1
};
return
spec_infos
[
spec
];
}
feat
(
i
,
0
)
=
static_cast
<
uint8_t
>
(
idx
>>
8
);
feat
(
i
,
1
)
=
static_cast
<
uint8_t
>
(
idx
&
0xff
);
}
void
_set_obs_action_msg
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
msg
)
{
feat
(
i
,
2
)
=
msg_to_id
(
msg
);
return
it
->
second
;
}
void
_set_obs_action_
act
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
act
,
uint8_t
act_offset
=
0
)
{
feat
(
i
,
3
)
=
cmd_act_to_id
(
act
)
+
act_offset
;
void
_set_obs_action_
spec
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
idx
)
{
feat
(
i
,
0
)
=
static_cast
<
uint8_t
>
(
idx
)
;
}
void
_set_obs_action_yesno
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
yesno
)
{
feat
(
i
,
4
)
=
cmd_yesno_to_id
(
yesno
);
void
_set_obs_action_card_id
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
CardId
cid
)
{
feat
(
i
,
1
)
=
static_cast
<
uint8_t
>
(
cid
>>
8
);
feat
(
i
,
2
)
=
static_cast
<
uint8_t
>
(
cid
&
0xff
);
}
void
_set_obs_action_
phase
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
phase
)
{
feat
(
i
,
5
)
=
cmd_phase_to_id
(
phase
);
void
_set_obs_action_
msg
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
msg
)
{
feat
(
i
,
3
)
=
msg_to_id
(
msg
);
}
void
_set_obs_action_
cancel
(
TArray
<
uint8_t
>
&
feat
,
int
i
)
{
feat
(
i
,
6
)
=
1
;
void
_set_obs_action_
act
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
ActionAct
act
)
{
feat
(
i
,
4
)
=
static_cast
<
uint8_t
>
(
act
)
;
}
void
_set_obs_action_finish
(
TArray
<
uint8_t
>
&
feat
,
int
i
)
{
feat
(
i
,
7
)
=
1
;
feat
(
i
,
5
)
=
1
;
}
// Writes the encoded effect id of an action into column 6 of row `i`.
//
// Encoding of the stored id:
//   0     : None        (effect == -1)
//   1     : default     (effect == 0)
//   2-15  : card effect (effect >= CARD_EFFECT_OFFSET, rebased to start at 2)
//   16+   : system      (any other value, mapped through system_string_to_id)
void _set_obs_action_effect(TArray<uint8_t> &feat, int i, int effect) {
  int encoded;
  if (effect == -1) {
    encoded = 0;
  } else if (effect == 0) {
    encoded = 1;
  } else if (effect >= CARD_EFFECT_OFFSET) {
    encoded = effect - CARD_EFFECT_OFFSET + 2;
  } else {
    encoded = system_string_to_id(effect);
  }
  feat(i, 6) = static_cast<uint8_t>(encoded);
}
void
_set_obs_action_position
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
position
)
{
position
=
1
<<
(
position
-
'1'
);
feat
(
i
,
8
)
=
position_to_id
(
position
);
void
_set_obs_action_phase
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
ActionPhase
phase
){
feat
(
i
,
7
)
=
static_cast
<
uint8_t
>
(
phase
);
}
void
_set_obs_action_
option
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
op
tion
)
{
feat
(
i
,
9
)
=
option
-
'0'
;
void
_set_obs_action_
position
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
posi
tion
)
{
feat
(
i
,
8
)
=
position_to_id
(
position
)
;
}
void
_set_obs_action_number
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
char
number
)
{
feat
(
i
,
10
)
=
number
-
'0'
;
void
_set_obs_action_number
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
number
)
{
feat
(
i
,
9
)
=
number
;
}
void
_set_obs_action_place
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
const
std
::
string
&
spec
)
{
feat
(
i
,
1
1
)
=
cmd_place_to_id
(
spec
);
void
_set_obs_action_place
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
ActionPlace
place
)
{
feat
(
i
,
1
0
)
=
static_cast
<
uint8_t
>
(
place
);
}
void
_set_obs_action_attrib
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
uint8_t
attrib
)
{
feat
(
i
,
1
2
)
=
attribute_to_id
(
attrib
);
feat
(
i
,
1
1
)
=
attribute_to_id
(
attrib
);
}
void
_set_obs_action
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
msg
,
const
std
::
string
&
option
,
const
SpecIndex
&
spec2index
,
CardId
card_id
)
{
void
_set_obs_action
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
const
LegalAction
&
action
)
{
auto
msg
=
action
.
msg_
;
_set_obs_action_msg
(
feat
,
i
,
msg
);
if
(
msg
==
MSG_SELECT_IDLECMD
)
{
if
(
option
==
"b"
||
option
==
"e"
)
{
_set_obs_action_phase
(
feat
,
i
,
option
[
0
]);
}
else
{
auto
act
=
option
[
0
];
auto
spec
=
option
.
substr
(
2
);
uint8_t
offset
=
0
;
int
n
=
spec
.
size
();
if
(
act
==
'v'
&&
std
::
isalpha
(
spec
[
n
-
1
]))
{
offset
=
spec
[
n
-
1
]
-
'a'
;
spec
=
spec
.
substr
(
0
,
n
-
1
);
}
_set_obs_action_act
(
feat
,
i
,
act
,
offset
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
}
}
else
if
(
msg
==
MSG_SELECT_CHAIN
)
{
if
(
option
[
0
]
==
'c'
)
{
_set_obs_action_cancel
(
feat
,
i
);
}
else
{
char
act
=
'v'
;
auto
spec
=
option
;
uint8_t
offset
=
0
;
auto
n
=
spec
.
size
();
if
(
std
::
isalpha
(
spec
[
n
-
1
]))
{
offset
=
spec
[
n
-
1
]
-
'a'
;
spec
=
spec
.
substr
(
0
,
n
-
1
);
}
_set_obs_action_act
(
feat
,
i
,
act
,
offset
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
}
}
else
if
(
msg
==
MSG_SELECT_CARD
||
msg
==
MSG_SELECT_TRIBUTE
||
msg
==
MSG_SELECT_SUM
||
msg
==
MSG_SELECT_UNSELECT_CARD
)
{
if
(
option
[
0
]
==
'f'
)
{
_set_obs_action_card_id
(
feat
,
i
,
action
.
cid_
);
if
(
msg
==
MSG_SELECT_CARD
||
msg
==
MSG_SELECT_TRIBUTE
||
msg
==
MSG_SELECT_SUM
||
msg
==
MSG_SELECT_UNSELECT_CARD
)
{
if
(
action
.
finish_
)
{
_set_obs_action_finish
(
feat
,
i
);
}
else
{
_set_obs_action_spec
(
feat
,
i
,
option
,
spec2index
,
card_id
);
_set_obs_action_spec
(
feat
,
i
,
action
.
spec_index_
);
}
}
else
if
(
msg
==
MSG_SELECT_POSITION
)
{
_set_obs_action_position
(
feat
,
i
,
option
[
0
]
);
_set_obs_action_position
(
feat
,
i
,
action
.
position_
);
}
else
if
(
msg
==
MSG_SELECT_EFFECTYN
)
{
auto
spec
=
option
.
substr
(
2
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
_set_obs_action_yesno
(
feat
,
i
,
option
[
0
]);
}
else
if
(
msg
==
MSG_SELECT_YESNO
)
{
_set_obs_action_yesno
(
feat
,
i
,
option
[
0
]);
}
else
if
(
msg
==
MSG_SELECT_BATTLECMD
)
{
if
(
option
==
"m"
||
option
==
"e"
)
{
_set_obs_action_phase
(
feat
,
i
,
option
[
0
]);
}
else
{
auto
act
=
option
[
0
];
auto
spec
=
option
.
substr
(
2
);
_set_obs_action_act
(
feat
,
i
,
act
);
_set_obs_action_spec
(
feat
,
i
,
spec
,
spec2index
,
card_id
);
}
}
else
if
(
msg
==
MSG_SELECT_OPTION
)
{
_set_obs_action_option
(
feat
,
i
,
option
[
0
]);
_set_obs_action_spec
(
feat
,
i
,
action
.
spec_index_
);
_set_obs_action_act
(
feat
,
i
,
action
.
act_
);
_set_obs_action_effect
(
feat
,
i
,
action
.
effect_
);
}
else
if
(
msg
==
MSG_SELECT_YESNO
||
msg
==
MSG_SELECT_OPTION
)
{
_set_obs_action_act
(
feat
,
i
,
action
.
act_
);
_set_obs_action_effect
(
feat
,
i
,
action
.
effect_
);
}
else
if
(
msg
==
MSG_SELECT_BATTLECMD
||
msg
==
MSG_SELECT_IDLECMD
||
msg
==
MSG_SELECT_CHAIN
)
{
_set_obs_action_phase
(
feat
,
i
,
action
.
phase_
);
_set_obs_action_spec
(
feat
,
i
,
action
.
spec_index_
);
_set_obs_action_act
(
feat
,
i
,
action
.
act_
);
_set_obs_action_effect
(
feat
,
i
,
action
.
effect_
);
}
else
if
(
msg
==
MSG_SELECT_PLACE
||
msg_
==
MSG_SELECT_DISFIELD
)
{
_set_obs_action_place
(
feat
,
i
,
option
);
_set_obs_action_place
(
feat
,
i
,
action
.
place_
);
}
else
if
(
msg
==
MSG_ANNOUNCE_ATTRIB
)
{
_set_obs_action_attrib
(
feat
,
i
,
1
<<
(
option
[
0
]
-
'1'
)
);
_set_obs_action_attrib
(
feat
,
i
,
action
.
attribute_
);
}
else
if
(
msg
==
MSG_ANNOUNCE_NUMBER
)
{
_set_obs_action_number
(
feat
,
i
,
option
[
0
]
);
_set_obs_action_number
(
feat
,
i
,
action
.
number_
);
}
else
{
throw
std
::
runtime_error
(
"Unsupported message "
+
std
::
to_string
(
msg
));
}
...
...
@@ -2302,49 +2523,42 @@ private:
CardId
spec_to_card_id
(
const
std
::
string
&
spec
,
PlayerId
player
)
{
int
offset
=
0
;
// TODO: possible info leak
bool
opponent
=
false
;
if
(
spec
[
0
]
==
'o'
)
{
player
=
1
-
player
;
opponent
=
true
;
offset
++
;
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
.
substr
(
offset
));
return
c_get_card_id
(
get_card_code
(
player
,
loc
,
seq
));
}
CardId
parse_card_id
(
const
std
::
string
&
option
,
PlayerId
player
)
{
CardId
card_id
=
0
;
if
(
msg_
==
MSG_SELECT_IDLECMD
)
{
if
(
!
(
option
==
"b"
||
option
==
"e"
))
{
auto
n
=
option
.
size
();
if
(
std
::
isalpha
(
option
[
n
-
1
]))
{
card_id
=
spec_to_card_id
(
option
.
substr
(
2
,
n
-
3
),
player
);
}
else
{
card_id
=
spec_to_card_id
(
option
.
substr
(
2
),
player
);
}
if
(
opponent
)
{
bool
hidden_for_opponent
=
true
;
if
(
loc
==
LOCATION_MZONE
||
loc
==
LOCATION_SZONE
||
loc
==
LOCATION_GRAVE
||
loc
==
LOCATION_REMOVED
)
{
hidden_for_opponent
=
false
;
}
}
else
if
(
msg_
==
MSG_SELECT_CHAIN
)
{
if
(
option
!=
"c"
)
{
card_id
=
spec_to_card_id
(
option
,
player
);
if
(
revealed_
.
size
()
!=
0
)
{
hidden_for_opponent
=
false
;
}
}
else
if
(
msg_
==
MSG_SELECT_CARD
||
msg_
==
MSG_SELECT_TRIBUTE
||
msg_
==
MSG_SELECT_SUM
||
msg_
==
MSG_SELECT_UNSELECT_CARD
)
{
if
(
option
[
0
]
!=
'f'
)
{
card_id
=
spec_to_card_id
(
option
,
player
);
if
(
hidden_for_opponent
)
{
return
0
;
}
}
else
if
(
msg_
==
MSG_SELECT_EFFECTYN
)
{
card_id
=
spec_to_card_id
(
option
.
substr
(
2
),
player
);
}
else
if
(
msg_
==
MSG_SELECT_BATTLECMD
)
{
if
(
!
(
option
==
"m"
||
option
==
"e"
))
{
card_id
=
spec_to_card_id
(
option
.
substr
(
2
),
player
);
Card
c
=
get_card
(
player
,
loc
,
seq
);
bool
hide
=
c
.
position_
&
POS_FACEDOWN
;
if
(
revealed_
.
find
(
spec
)
!=
revealed_
.
end
())
{
hide
=
false
;
}
CardId
card_id
=
0
;
if
(
!
hide
)
{
card_id
=
c_get_card_id
(
c
.
code_
);
}
}
return
c
ard_id
;
return
c
_get_card_id
(
get_card_code
(
player
,
loc
,
seq
))
;
}
void
_set_obs_actions
(
TArray
<
uint8_t
>
&
feat
,
const
SpecIndex
&
spec2index
,
int
msg
,
const
std
::
vector
<
std
::
string
>
&
options
)
{
for
(
int
i
=
0
;
i
<
options
.
size
();
++
i
)
{
_set_obs_action
(
feat
,
i
,
msg
,
options
[
i
],
spec2index
,
0
);
void
_set_obs_actions
(
TArray
<
uint8_t
>
&
feat
,
const
std
::
vector
<
LegalAction
>
&
actions
)
{
for
(
int
i
=
0
;
i
<
actions
.
size
();
++
i
)
{
_set_obs_action
(
feat
,
i
,
actions
[
i
]);
}
}
...
...
@@ -2451,7 +2665,7 @@ private:
void
WriteState
(
float
reward
,
int
win_reason
=
0
)
{
State
state
=
Allocate
();
int
n_options
=
op
tions_
.
size
();
int
n_options
=
legal_ac
tions_
.
size
();
state
[
"reward"
_
]
=
reward
;
state
[
"info:to_play"
_
]
=
int
(
to_play_
);
state
[
"info:is_selfplay"
_
]
=
int
(
play_mode_
==
kSelfPlay
);
...
...
@@ -2463,62 +2677,69 @@ private:
return
;
}
auto
[
spec
2index
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
auto
[
spec
_infos
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
// we can't shuffle because idx must be stable in callback
if
(
n_options
>
max_options
())
{
op
tions_
.
resize
(
max_options
());
legal_ac
tions_
.
resize
(
max_options
());
}
// print spec2index
// for (auto const& [key, val] : spec2index) {
// fmt::println("{} {}", key, val);
// }
_set_obs_actions
(
state
[
"obs:actions_"
_
],
spec2index
,
msg_
,
options_
);
n_options
=
options_
.
size
();
n_options
=
legal_actions_
.
size
();
state
[
"info:num_options"
_
]
=
n_options
;
// update_h_card_ids from state
for
(
int
i
=
0
;
i
<
n_options
;
++
i
)
{
uint8_t
spec_index1
=
state
[
"obs:actions_"
_
](
i
,
0
)
;
uint8_t
spec_index2
=
state
[
"obs:actions_"
_
](
i
,
1
)
;
uint16_t
spec_index
=
(
static_cast
<
uint16_t
>
(
spec_index1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
spec_index2
)
;
if
(
spec_index
==
0
)
{
h_card_ids_
[
i
]
=
0
;
}
else
{
uint8_t
card_id1
=
state
[
"obs:cards_"
_
](
spec_index
-
1
,
0
);
uint8_t
card_id2
=
state
[
"obs:cards_"
_
](
spec_index
-
1
,
1
)
;
h_card_ids_
[
i
]
=
(
static_cast
<
uint16_t
>
(
card_id1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
card_id2
);
auto
&
action
=
legal_actions_
[
i
]
;
action
.
msg_
=
msg_
;
const
auto
&
spec
=
action
.
spec_
;
if
(
!
spec
.
empty
()
)
{
const
auto
&
spec_info
=
find_spec_info
(
spec_infos
,
spec
)
;
action
.
spec_index_
=
spec_info
.
index
;
if
(
action
.
cid_
==
0
)
{
action
.
cid_
=
spec_info
.
cid
;
}
}
}
_set_obs_actions
(
state
[
"obs:actions_"
_
],
legal_actions_
);
// write history actions
int
offset
=
n_history_actions_
-
ha_p_
;
int
n_h_action_feats
=
history_actions_
.
Shape
()[
1
];
auto
ha_p
=
to_play_
==
0
?
ha_p_1_
:
ha_p_2_
;
auto
&
history_actions
=
to_play_
==
0
?
history_actions_1_
:
history_actions_2_
;
int
offset
=
n_history_actions_
-
ha_p
;
int
n_h_action_feats
=
history_actions
.
Shape
()[
1
];
state
[
"obs:h_actions_"
_
].
Assign
(
(
uint8_t
*
)
history_actions
_
[
ha_p_
].
Data
(),
n_h_action_feats
*
offset
);
(
uint8_t
*
)
history_actions
[
ha_p
].
Data
(),
n_h_action_feats
*
offset
);
state
[
"obs:h_actions_"
_
][
offset
].
Assign
(
(
uint8_t
*
)
history_actions
_
.
Data
(),
n_h_action_feats
*
ha_p_
);
(
uint8_t
*
)
history_actions
.
Data
(),
n_h_action_feats
*
ha_p
);
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
if
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
2
))
==
0
)
{
if
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
3
))
==
0
)
{
break
;
}
state
[
"obs:h_actions_"
_
](
i
,
13
)
=
static_cast
<
uint8_t
>
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
13
))
==
to_play_
);
int
turn_diff
=
std
::
min
(
16
,
turn_count_
-
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
1
4
)));
state
[
"obs:h_actions_"
_
](
i
,
1
4
)
=
static_cast
<
uint8_t
>
(
turn_diff
);
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12
)) == to_play_);
int
turn_diff
=
std
::
min
(
16
,
turn_count_
-
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
1
2
)));
state
[
"obs:h_actions_"
_
](
i
,
1
2
)
=
static_cast
<
uint8_t
>
(
turn_diff
);
}
}
// Prints the action chosen by the player to move.
//
// idx is the position of the chosen entry in legal_actions_. The entry is
// rendered with the first label that applies: its card spec string, then its
// place, then its position, and otherwise fmt's formatting of the action
// itself. The line is written once, followed by the full legal_actions_ list.
//
// The earlier duplicate print that indexed the legacy options_ string list
// was removed so that the decision is reported exactly once and always from
// legal_actions_.
void show_decision(int idx) {
  const auto& a = legal_actions_[idx];
  std::string s;
  if (!a.spec_.empty()) {
    s = a.spec_;
  } else if (a.place_ != ActionPlace::None) {
    s = action_place_to_string(a.place_);
  } else if (a.position_ != 0) {
    s = position_to_string(a.position_);
  } else {
    // No spec, place or position: fall back to the action's own formatter.
    s = fmt::format("{}", a);
  }
  fmt::print("Player {} chose \"{}\" in {}\n", to_play_, s, legal_actions_);
}
std
::
tuple
<
std
::
vector
<
CardCode
>
,
std
::
vector
<
CardCode
>
,
std
::
string
>
...
...
@@ -2581,15 +2802,19 @@ private:
handle_multi_select
();
}
else
{
handle_message
();
if
(
op
tions_
.
empty
())
{
if
(
legal_ac
tions_
.
empty
())
{
continue
;
}
}
if
((
play_mode_
==
kSelfPlay
)
||
(
to_play_
==
ai_player_
))
{
if
(
op
tions_
.
size
()
==
1
)
{
if
(
legal_ac
tions_
.
size
()
==
1
)
{
callback_
(
0
);
update_h_card_ids
(
to_play_
,
0
);
update_history_actions
(
to_play_
,
0
);
auto
la
=
legal_actions_
[
0
];
la
.
msg_
=
msg_
;
if
(
la
.
cid_
==
0
&&
!
la
.
spec_
.
empty
())
{
la
.
cid_
=
spec_to_card_id
(
la
.
spec_
,
to_play_
);
}
update_history_actions
(
to_play_
,
la
);
if
(
verbose_
)
{
show_decision
(
0
);
}
...
...
@@ -2597,7 +2822,7 @@ private:
return
;
}
}
else
{
auto
idx
=
players_
[
to_play_
]
->
think
(
op
tions_
);
auto
idx
=
players_
[
to_play_
]
->
think
(
legal_ac
tions_
);
callback_
(
idx
);
if
(
verbose_
)
{
show_decision
(
idx
);
...
...
@@ -2606,7 +2831,7 @@ private:
}
}
done_
=
true
;
op
tions_
.
clear
();
legal_ac
tions_
.
clear
();
}
// Reads one byte from data_ at the cursor dp_ and advances dp_ by one.
// No bounds check is performed here.
uint8_t read_u8() {
  const uint8_t value = data_[dp_];
  ++dp_;
  return value;
}
...
...
@@ -2653,7 +2878,12 @@ private:
int32_t
bl
=
YGO_QueryCard
(
pduel_
,
player
,
loc
,
seq
,
flags
,
query_buf_
);
qdp_
=
0
;
if
(
bl
<=
0
)
{
throw
std
::
runtime_error
(
"[get_card] Invalid card (bl <= 0)"
);
show_deck
(
0
);
show_deck
(
1
);
show_turn
();
show_buffer
();
auto
s
=
fmt
::
format
(
"[get_card] Invalid card (bl <= 0), player: {}, loc: {}, seq: {}"
,
player
,
loc
,
seq
);
throw
std
::
runtime_error
(
s
);
}
uint32_t
f
=
q_read_u32
();
if
(
f
==
LEN_EMPTY
)
{
...
...
@@ -2728,7 +2958,7 @@ private:
c
.
attack_
=
q_read_u32
();
c
.
defense_
=
q_read_u32
();
// TODO: equip_target
// TODO
(2)
: equip_target
if
(
f
&
QUERY_EQUIP_CARD
)
{
q_read_u32
();
}
...
...
@@ -2744,7 +2974,7 @@ private:
cards
.
push_back
(
c_
);
}
// TODO: counters
// TODO
(2)
: counters
uint32_t
n_counters
=
q_read_u32
();
for
(
int
i
=
0
;
i
<
n_counters
;
++
i
)
{
if
(
i
==
0
)
{
...
...
@@ -2803,7 +3033,7 @@ private:
auto
controller
=
read_u8
();
auto
loc
=
read_u8
();
auto
seq
=
read_u8
();
uint32_t
data
=
-
1
;
uint32_t
data
=
0
;
if
(
extra
)
{
if
(
extra8
)
{
data
=
read_u8
();
...
...
@@ -2816,6 +3046,23 @@ private:
return
card_specs
;
}
// Splits a description value into a (card code, effect index) pair.
//
// - desc < DESCRIPTION_LIMIT: returned as (0, desc) unchanged.
// - otherwise: the card code is desc >> 4 and the index is the low 4 bits of
//   desc, returned shifted by CARD_EFFECT_OFFSET.
//
// `code` is only used in the diagnostic output below.
//
// Throws std::runtime_error when the low 4 bits are outside [0, 14); before
// throwing, the values are printed and show_deck / show_buffer / show_turn
// are called.
std::tuple<CardCode, int> unpack_desc(CardCode code, uint32_t desc) {
  if (desc < DESCRIPTION_LIMIT) {
    return {0, desc};
  }

  CardCode packed_code = desc >> 4;
  int effect_idx = desc & 0xf;

  const bool idx_in_range = (effect_idx >= 0) && (effect_idx < 14);
  if (idx_in_range) {
    return {packed_code, effect_idx + CARD_EFFECT_OFFSET};
  }

  // Out-of-range index: dump state for debugging, then fail.
  fmt::print("Code: {}, Code_: {}, Desc: {}\n", code, packed_code, desc);
  show_deck(0);
  show_deck(1);
  show_buffer();
  show_turn();
  throw std::runtime_error("Invalid effect index: " + std::to_string(effect_idx));
}
std
::
string
cardlist_info_for_player
(
const
Card
&
card
,
PlayerId
pl
)
{
std
::
string
spec
=
card
.
get_spec
(
pl
);
if
(
card
.
location_
==
LOCATION_DECK
)
{
...
...
@@ -2833,7 +3080,7 @@ private:
// 3. update to_play_ and options_ if need action
void
handle_message
()
{
msg_
=
int
(
data_
[
dp_
++
]);
op
tions_
=
{};
legal_ac
tions_
=
{};
if
(
verbose_
)
{
fmt
::
println
(
"Message {}, length {}, dp {}"
,
msg_to_string
(
msg_
),
dl_
,
dp_
);
...
...
@@ -3097,11 +3344,11 @@ private:
uint8_t
pos
=
read_u8
();
uint8_t
type
=
read_u8
();
uint32_t
value
=
read_u32
();
Card
card
=
get_card
(
player
,
loc
,
seq
);
if
(
card
.
code_
==
0
)
{
return
;
}
if
(
type
==
CHINT_RACE
)
{
Card
card
=
get_card
(
player
,
loc
,
seq
);
if
(
card
.
code_
==
0
)
{
return
;
}
std
::
string
races_str
=
"TODO"
;
for
(
PlayerId
pl
=
0
;
pl
<
2
;
pl
++
)
{
players_
[
pl
]
->
notify
(
fmt
::
format
(
"{} ({}) selected {}."
,
...
...
@@ -3109,6 +3356,10 @@ private:
races_str
));
}
}
else
if
(
type
==
CHINT_ATTRIBUTE
)
{
Card
card
=
get_card
(
player
,
loc
,
seq
);
if
(
card
.
code_
==
0
)
{
return
;
}
std
::
string
attributes_str
=
"TODO"
;
for
(
PlayerId
pl
=
0
;
pl
<
2
;
pl
++
)
{
players_
[
pl
]
->
notify
(
fmt
::
format
(
"{} ({}) selected {}."
,
...
...
@@ -3229,7 +3480,7 @@ private:
return
;
}
dp_
+=
6
;
// TODO: implement output
// TODO
(3)
: implement output
}
else
if
(
msg_
==
MSG_CARD_TARGET
)
{
if
(
!
verbose_
)
{
dp_
=
dl_
;
...
...
@@ -3301,7 +3552,7 @@ private:
players_
[
pl
]
->
notify
(
str
);
}
}
else
if
(
msg_
==
MSG_SORT_CARD
)
{
// TODO: implement action
// TODO
(3)
: implement action
if
(
!
verbose_
)
{
dp_
=
dl_
;
resp_buf_
[
0
]
=
255
;
...
...
@@ -3374,7 +3625,7 @@ private:
auto
pl
=
players_
[
player
];
PlayerId
op_id
=
1
-
player
;
auto
op
=
players_
[
op_id
];
// TODO: counter type to string
// TODO
(3)
: counter type to string
pl
->
notify
(
fmt
::
format
(
"{} counter(s) of type {} placed on {} ()."
,
count
,
"UNK"
,
c
.
name_
,
c
.
get_spec
(
player
)));
op
->
notify
(
fmt
::
format
(
"{} counter(s) of type {} placed on {} ()."
,
count
,
"UNK"
,
c
.
name_
,
c
.
get_spec
(
op_id
)));
}
else
if
(
msg_
==
MSG_REMOVE_COUNTER
)
{
...
...
@@ -3406,7 +3657,7 @@ private:
dp_
=
dl_
;
return
;
}
// TODO: implement output
// TODO
(3)
: implement output
dp_
=
dl_
;
}
else
if
(
msg_
==
MSG_SHUFFLE_DECK
)
{
if
(
!
verbose_
)
{
...
...
@@ -3699,52 +3950,64 @@ private:
if
(
verbose_
)
{
pl
->
notify
(
"Battle menu:"
);
}
for
(
const
auto
[
code
,
spec
,
data
]
:
activatable
)
{
// TODO: Add effect description to indicate which effect is being activated
options_
.
push_back
(
"v "
+
spec
);
for
(
const
auto
[
code_t
,
spec
,
desc
]
:
activatable
)
{
CardCode
code
=
code_t
;
if
(
code
&
0x80000000
)
{
code
&=
0x7fffffff
;
}
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
auto
c
=
get_card
(
player
,
loc
,
seq
);
pl
->
notify
(
"v "
+
spec
+
": activate "
+
c
.
name_
+
" ("
+
std
::
to_string
(
c
.
attack_
)
+
"/"
+
std
::
to_string
(
c
.
defense_
)
+
")"
);
auto
c
=
c_get_card
(
code
);
int
cmd_idx
=
legal_actions_
.
size
(
);
std
::
string
s
=
fmt
::
format
(
"{}: activate {}({}) [{}/{}] ({})"
,
cmd_idx
,
c
.
name_
,
spec
,
c
.
attack_
,
c
.
defense_
,
c
.
get_effect_description
(
code_d
,
eff_idx
)
);
}
}
for
(
const
auto
[
code
,
spec
,
data
]
:
attackable
)
{
// TODO: add this as feature
bool
direct_attackable
=
data
&
0x1
;
options_
.
push_back
(
"a "
+
spec
);
auto
act
=
direct_attackable
?
ActionAct
::
DirectAttack
:
ActionAct
::
Attack
;
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
act
,
spec
));
if
(
verbose_
)
{
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
auto
c
=
get_card
(
player
,
loc
,
seq
);
std
::
string
s
;
auto
[
controller
,
loc
,
seq
,
pos
]
=
spec_to_ls
(
player
,
spec
);
auto
c
=
get_card
(
controller
,
loc
,
seq
);
int
cmd_idx
=
legal_actions_
.
size
();
auto
attack_str
=
direct_attackable
?
"direct attack"
:
"attack"
;
std
::
string
s
=
fmt
::
format
(
"{}: {} {}({}) "
,
cmd_idx
,
attack_str
,
c
.
name_
,
spec
);
if
(
c
.
type_
&
TYPE_LINK
)
{
s
=
"a "
+
spec
+
": "
+
c
.
name_
+
" ("
+
std
::
to_string
(
c
.
attack_
)
+
")"
;
s
+=
fmt
::
format
(
"[{}]"
,
c
.
attack_
);
}
else
{
s
=
"a "
+
spec
+
": "
+
c
.
name_
+
" ("
+
std
::
to_string
(
c
.
attack_
)
+
"/"
+
std
::
to_string
(
c
.
defense_
)
+
")"
;
}
if
(
direct_attackable
)
{
s
+=
" direct attack"
;
}
else
{
s
+=
" attack"
;
s
+=
fmt
::
format
(
"[{}/{}]"
,
c
.
attack_
,
c
.
defense_
);
}
pl
->
notify
(
s
);
}
}
if
(
to_m2
)
{
options_
.
push_back
(
"m"
);
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
Main2
));
int
cmd_idx
=
legal_actions_
.
size
();
if
(
verbose_
)
{
pl
->
notify
(
"m: Main phase 2."
);
pl
->
notify
(
fmt
::
format
(
"{}: Main phase 2."
,
cmd_idx
)
);
}
}
if
(
to_ep
)
{
if
(
!
to_m2
)
{
options_
.
push_back
(
"e"
);
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
End
));
int
cmd_idx
=
legal_actions_
.
size
();
if
(
verbose_
)
{
pl
->
notify
(
"e: End phase."
);
pl
->
notify
(
fmt
::
format
(
"{}: End phase."
,
cmd_idx
)
);
}
}
}
...
...
@@ -3752,14 +4015,15 @@ private:
int
n_attackables
=
attackable
.
size
();
to_play_
=
player
;
callback_
=
[
this
,
n_activatables
,
n_attackables
,
to_ep
,
to_m2
](
int
idx
)
{
const
auto
&
la
=
legal_actions_
[
idx
];
if
(
idx
<
n_activatables
)
{
YGO_SetResponsei
(
pduel_
,
idx
<<
16
);
}
else
if
(
idx
<
(
n_activatables
+
n_attackables
))
{
idx
=
idx
-
n_activatables
;
YGO_SetResponsei
(
pduel_
,
(
idx
<<
16
)
+
1
);
}
else
if
((
options_
[
idx
]
==
"e"
)
&&
to_ep
)
{
}
else
if
((
la
.
phase_
==
ActionPhase
::
End
)
&&
to_ep
)
{
YGO_SetResponsei
(
pduel_
,
3
);
}
else
if
((
options_
[
idx
]
==
"m"
)
&&
to_m2
)
{
}
else
if
((
la
.
phase_
==
ActionPhase
::
Main2
)
&&
to_m2
)
{
YGO_SetResponsei
(
pduel_
,
2
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option"
);
...
...
@@ -3777,21 +4041,18 @@ private:
std
::
vector
<
std
::
string
>
select_specs
;
select_specs
.
reserve
(
select_size
);
if
(
verbose_
)
{
std
::
vector
<
Card
>
cards
;
auto
pl
=
players_
[
player
];
pl
->
notify
(
"Select "
+
std
::
to_string
(
min
)
+
" to "
+
std
::
to_string
(
max
)
+
" cards:"
);
for
(
int
i
=
0
;
i
<
select_size
;
++
i
)
{
auto
code
=
read_u32
();
auto
loc
=
read_u32
();
Card
card
=
c_get_card
(
code
);
card
.
set_location
(
loc
);
cards
.
push_back
(
card
);
}
auto
pl
=
players_
[
player
];
pl
->
notify
(
"Select "
+
std
::
to_string
(
min
)
+
" to "
+
std
::
to_string
(
max
)
+
" cards:"
);
for
(
const
auto
&
card
:
cards
)
{
auto
spec
=
card
.
get_spec
(
player
);
select_specs
.
push_back
(
spec
);
pl
->
notify
(
spec
+
": "
+
card
.
name_
);
auto
s
=
fmt
::
format
(
"{}: {}({})"
,
i
+
1
,
card
.
name_
,
spec
);
pl
->
notify
(
s
);
}
}
else
{
for
(
int
i
=
0
;
i
<
select_size
;
++
i
)
{
...
...
@@ -3807,22 +4068,22 @@ private:
auto
unselect_size
=
read_u8
();
// unselect not allowed (no regrets
!
)
// unselect not allowed (no regrets)
dp_
+=
8
*
unselect_size
;
for
(
int
j
=
0
;
j
<
select_specs
.
size
();
++
j
)
{
options_
.
push_back
(
select_specs
[
j
]
);
legal_actions_
.
push_back
(
LegalAction
::
from_spec
(
select_specs
[
j
])
);
}
if
(
finishable
)
{
options_
.
push_back
(
"f"
);
legal_actions_
.
push_back
(
LegalAction
::
finish
()
);
}
// cancelable and finishable not needed
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
if
(
options_
[
idx
]
==
"f"
)
{
if
(
legal_actions_
[
idx
].
finish_
)
{
YGO_SetResponsei
(
pduel_
,
-
1
);
}
else
{
resp_buf_
[
0
]
=
1
;
...
...
@@ -3893,7 +4154,7 @@ private:
}
}
// TODO: use this when added to history actions
// TODO
(1)
: use this when added to history actions
// if ((min == max) && (max == specs.size())) {
// resp_buf_[0] = specs.size();
// for (int i = 0; i < specs.size(); ++i) {
...
...
@@ -3974,7 +4235,7 @@ private:
// combs = combinations_with_weight(release_params, min);
}
// TODO: use this when added to history actions
// TODO
(1)
: use this when added to history actions
// if (max == specs.size()) {
// // tribute all
// resp_buf_[0] = specs.size();
...
...
@@ -4126,25 +4387,18 @@ private:
// auto hint_timing = read_u32();
// auto other_timing = read_u32();
std
::
vector
<
Card
>
card
s
;
std
::
vector
<
Card
Code
>
code
s
;
std
::
vector
<
uint32_t
>
descs
;
std
::
vector
<
uint32_t
>
spec_code
s
;
std
::
vector
<
std
::
string
>
spec
s
;
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
auto
et
=
read_u8
();
auto
flag
=
read_u8
();
CardCode
code
=
read_u32
();
if
(
verbose_
)
{
uint32_t
loc
=
read_u32
();
Card
card
=
c_get_card
(
code
);
card
.
set_location
(
loc
);
cards
.
push_back
(
card
);
spec_codes
.
push_back
(
card
.
get_spec_code
(
player
));
}
else
{
PlayerId
c
=
read_u8
();
uint8_t
loc
=
read_u8
();
uint8_t
seq
=
read_u8
();
uint8_t
pos
=
read_u8
();
spec_codes
.
push_back
(
ls_to_spec_code
(
loc
,
seq
,
pos
,
c
!=
player
));
}
codes
.
push_back
(
code
);
PlayerId
c
=
read_u8
();
uint8_t
loc
=
read_u8
();
uint8_t
seq
=
read_u8
();
uint8_t
pos
=
read_u8
();
specs
.
push_back
(
ls_to_spec
(
loc
,
seq
,
pos
,
c
!=
player
));
uint32_t
desc
=
read_u32
();
descs
.
push_back
(
desc
);
}
...
...
@@ -4168,58 +4422,42 @@ private:
op
->
seen_waiting_
=
true
;
}
std
::
vector
<
int
>
chain_index
;
ankerl
::
unordered_dense
::
map
<
uint32_t
,
int
>
chain_counts
;
ankerl
::
unordered_dense
::
map
<
uint32_t
,
int
>
chain_orders
;
std
::
vector
<
std
::
string
>
chain_specs
;
std
::
vector
<
std
::
string
>
effect_descs
;
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
chain_index
.
push_back
(
i
);
chain_counts
[
spec_codes
[
i
]]
+=
1
;
if
(
verbose_
)
{
pl
->
notify
(
"Select chain:"
);
}
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
auto
spec_code
=
spec_codes
[
i
];
auto
cs
=
code_to_spec
(
spec_code
);
auto
chain_count
=
chain_counts
[
spec_code
];
if
(
chain_count
>
1
)
{
// TODO: should use desc to indicate activate which effect
cs
.
push_back
(
'a'
+
chain_orders
[
spec_code
]);
}
chain_orders
[
spec_code
]
++
;
chain_specs
.
push_back
(
cs
);
if
(
verbose_
)
{
const
auto
&
card
=
cards
[
i
];
effect_descs
.
push_back
(
card
.
get_effect_description
(
descs
[
i
],
true
));
CardCode
code
=
codes
[
i
];
uint32_t
desc
=
descs
[
i
];
auto
spec
=
specs
[
i
];
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
}
if
(
verbose_
)
{
if
(
forced
)
{
pl
->
notify
(
"Select chain:"
);
}
else
{
pl
->
notify
(
"Select chain (c to cancel):"
);
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
const
auto
&
effect_desc
=
effect_descs
[
i
];
if
(
effect_desc
.
empty
())
{
pl
->
notify
(
chain_specs
[
i
]
+
": "
+
cards
[
i
].
name_
);
}
else
{
pl
->
notify
(
chain_specs
[
i
]
+
" ("
+
cards
[
i
].
name_
+
"): "
+
effect_desc
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
auto
c
=
c_get_card
(
code
);
std
::
string
s
=
fmt
::
format
(
"{}: {}({}) ({})"
,
i
+
1
,
c
.
name_
,
spec
,
c
.
get_effect_description
(
code_d
,
eff_idx
));
pl
->
notify
(
s
);
}
}
for
(
const
auto
&
spec
:
chain_specs
)
{
options_
.
push_back
(
spec
);
}
if
(
!
forced
)
{
options_
.
push_back
(
"c"
);
legal_actions_
.
push_back
(
LegalAction
::
cancel
());
if
(
verbose_
)
{
pl
->
notify
(
fmt
::
format
(
"{}: cancel"
,
size
+
1
));
}
}
to_play_
=
player
;
callback_
=
[
this
,
forced
](
int
idx
)
{
const
auto
&
option
=
op
tions_
[
idx
];
if
(
option
==
"c"
)
{
const
auto
&
action
=
legal_ac
tions_
[
idx
];
if
(
action
.
act_
==
ActionAct
::
Cancel
)
{
if
(
forced
)
{
fmt
::
print
(
"cancel not allowed in forced chain
\n
"
);
YGO_SetResponsei
(
pduel_
,
0
);
...
...
@@ -4232,58 +4470,76 @@ private:
};
}
else
if
(
msg_
==
MSG_SELECT_YESNO
)
{
auto
player
=
read_u8
();
auto
desc
=
read_u32
();
auto
[
code
,
eff_idx
]
=
unpack_desc
(
0
,
desc
);
if
(
desc
==
0
)
{
show_buffer
();
auto
s
=
fmt
::
format
(
"Unknown desc {} in select_yesno"
,
desc
);
throw
std
::
runtime_error
(
s
);
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
""
);
if
(
code
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
auto
desc
=
read_u32
();
auto
pl
=
players_
[
player
];
std
::
string
opt
;
if
(
desc
>
10000
)
{
auto
code
=
desc
>>
4
;
auto
card
=
c_get_card
(
code
);
auto
opt_idx
=
desc
&
0xf
;
if
(
opt_idx
<
card
.
strings_
.
size
())
{
opt
=
card
.
strings_
[
opt_idx
];
std
::
string
s
;
if
(
code
==
0
)
{
s
=
get_system_string
(
eff_idx
);
}
else
{
Card
c
=
c_get_card
(
code
);
int
cmd_idx
=
legal_actions_
.
size
();
eff_idx
-=
CARD_EFFECT_OFFSET
;
if
(
eff_idx
>=
c
.
strings_
.
size
())
{
throw
std
::
runtime_error
(
fmt
::
format
(
"Unknown effect {} of {}"
,
eff_idx
,
c
.
name_
));
}
if
(
opt
.
empty
())
{
opt
=
"Unknown question from "
+
card
.
name_
+
". Yes or no?"
;
auto
str
=
c
.
strings_
[
eff_idx
];
if
(
str
.
empty
())
{
str
=
"effect "
+
std
::
to_string
(
eff_idx
);
}
}
else
{
opt
=
get_system_string
(
desc
);
s
=
fmt
::
format
(
"{} ({})"
,
c
.
name_
,
str
);
}
pl
->
notify
(
opt
);
pl
->
notify
(
"Please enter y or n."
);
}
else
{
dp_
+=
4
;
pl
->
notify
(
"1: "
+
s
);
pl
->
notify
(
"2: No"
);
}
options_
=
{
"y"
,
"n"
};
// TODO: maybe add card id to cancel
legal_actions_
.
push_back
(
LegalAction
::
cancel
());
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
if
(
idx
==
0
)
{
YGO_SetResponsei
(
pduel_
,
1
);
}
else
if
(
idx
==
1
)
{
YGO_SetResponsei
(
pduel_
,
0
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option"
);
}
};
}
else
if
(
msg_
==
MSG_SELECT_EFFECTYN
)
{
auto
player
=
read_u8
();
std
::
string
spec
;
CardCode
code
=
read_u32
();
auto
ct
=
read_u8
();
auto
loc
=
read_u8
();
auto
seq
=
read_u8
();
auto
pos
=
read_u8
();
auto
desc
=
read_u32
();
std
::
string
spec
=
ls_to_spec
(
loc
,
seq
,
pos
,
ct
!=
player
);
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
CardCode
code
=
read_u32
();
uint32_t
loc
=
read_u32
();
Card
card
=
c_get_card
(
code
);
card
.
set_location
(
loc
);
auto
desc
=
read_u32
();
Card
c
=
c_get_card
(
code
);
auto
pl
=
players_
[
player
];
spec
=
card
.
get_spec
(
player
);
auto
name
=
card
.
name_
;
auto
name
=
c
.
name_
;
std
::
string
s
;
if
(
desc
==
0
)
{
// From [%ls], activate [%ls]?
s
=
"From "
+
card
.
get_spec
(
player
)
+
", activate "
+
name
+
"?"
;
}
else
if
(
desc
<
2048
)
{
if
(
code_d
==
0
)
{
s
=
get_system_string
(
desc
);
std
::
string
fmt_str
=
"[%ls]"
;
auto
pos
=
find_substrs
(
s
,
fmt_str
);
...
...
@@ -4295,87 +4551,74 @@ private:
}
else
if
(
pos
.
size
()
==
2
)
{
auto
p1
=
pos
[
0
];
auto
p2
=
pos
[
1
];
s
=
s
.
substr
(
0
,
p1
)
+
card
.
get_spec
(
player
)
+
s
=
s
.
substr
(
0
,
p1
)
+
spec
+
s
.
substr
(
p1
+
fmt_str
.
size
(),
p2
-
p1
-
fmt_str
.
size
())
+
name
+
s
.
substr
(
p2
+
fmt_str
.
size
());
}
else
{
throw
std
::
runtime_error
(
"Unknown effectyn desc "
+
std
::
to_string
(
desc
)
+
" of "
+
name
);
}
}
else
if
(
desc
<
10000u
)
{
s
=
get_system_string
(
desc
);
}
else
{
CardCode
code
=
(
desc
>>
4
)
&
0x0fffffff
;
uint32_t
offset
=
desc
&
0xf
;
if
(
cards_
.
find
(
code
)
!=
cards_
.
end
())
{
auto
&
card_
=
c_get_card
(
code
);
s
=
card_
.
strings_
[
offset
];
if
(
s
.
empty
())
{
s
=
"???"
;
}
}
else
{
throw
std
::
runtime_error
(
"Unknown effectyn desc "
+
std
::
to_string
(
desc
)
+
" of "
+
name
);
}
s
=
fmt
::
format
(
"{}({}) ({})"
,
c
.
name_
,
spec
,
c
.
get_effect_description
(
code_d
,
eff_idx
));
}
pl
->
notify
(
s
);
pl
->
notify
(
"Please enter y or n."
);
}
else
{
dp_
+=
4
;
auto
c
=
read_u8
();
auto
loc
=
read_u8
();
auto
seq
=
read_u8
();
auto
pos
=
read_u8
();
dp_
+=
4
;
spec
=
ls_to_spec
(
loc
,
seq
,
pos
,
c
!=
player
);
pl
->
notify
(
"1: "
+
s
);
pl
->
notify
(
"2: No"
);
}
options_
=
{
"y "
+
spec
,
"n "
+
spec
};
// TODO: maybe add card info to cancel
legal_actions_
.
push_back
(
LegalAction
::
cancel
());
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
if
(
idx
==
0
)
{
YGO_SetResponsei
(
pduel_
,
1
);
}
else
if
(
idx
==
1
)
{
YGO_SetResponsei
(
pduel_
,
0
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option"
);
}
};
}
else
if
(
msg_
==
MSG_SELECT_OPTION
)
{
// TODO: add card information
auto
player
=
read_u8
();
auto
size
=
read_u8
();
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
pl
->
notify
(
"Select an option:"
);
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
auto
opt
=
read_u32
();
players_
[
player
]
->
notify
(
"Select an option:"
);
}
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
auto
desc
=
read_u32
();
auto
[
code
,
eff_idx
]
=
unpack_desc
(
0
,
desc
);
if
(
desc
==
0
)
{
show_buffer
();
auto
s
=
fmt
::
format
(
"Unknown desc {} in select_option"
,
desc
);
throw
std
::
runtime_error
(
s
);
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
""
);
if
(
code
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
std
::
string
s
;
if
(
opt
>
10000
)
{
CardCode
code
=
opt
>>
4
;
s
=
c_get_card
(
code
).
strings_
[
opt
&
0xf
];
if
(
code
==
0
)
{
s
=
get_system_string
(
eff_idx
);
}
else
{
s
=
get_system_string
(
opt
);
Card
c
=
c_get_card
(
code
);
int
cmd_idx
=
legal_actions_
.
size
();
eff_idx
-=
CARD_EFFECT_OFFSET
;
if
(
eff_idx
>=
c
.
strings_
.
size
())
{
throw
std
::
runtime_error
(
fmt
::
format
(
"Unknown effect {} of {}"
,
eff_idx
,
c
.
name_
));
}
auto
str
=
c
.
strings_
[
eff_idx
];
if
(
str
.
empty
())
{
str
=
"effect "
+
std
::
to_string
(
eff_idx
);
}
s
=
fmt
::
format
(
"{} ({})"
,
c
.
name_
,
str
);
}
std
::
string
option
=
std
::
to_string
(
i
+
1
);
options_
.
push_back
(
option
);
pl
->
notify
(
option
+
": "
+
s
);
}
}
else
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
dp_
+=
4
;
options_
.
push_back
(
std
::
to_string
(
i
+
1
));
players_
[
player
]
->
notify
(
std
::
to_string
(
i
+
1
)
+
": "
+
s
);
}
}
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
if
(
verbose_
)
{
players_
[
to_play_
]
->
notify
(
"You selected option "
+
options_
[
idx
]
+
"."
);
players_
[
1
-
to_play_
]
->
notify
(
players_
[
to_play_
]
->
nickname_
+
" selected option "
+
options_
[
idx
]
+
"."
);
}
YGO_SetResponsei
(
pduel_
,
idx
);
};
}
else
if
(
msg_
==
MSG_SELECT_IDLECMD
)
{
...
...
@@ -4397,90 +4640,97 @@ private:
pl
->
notify
(
"Select a card and action to perform."
);
}
for
(
const
auto
&
[
code
,
spec
,
data
]
:
summonable_
)
{
std
::
string
option
=
"s "
+
spec
;
options_
.
push_back
(
option
);
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
Summon
,
spec
));
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Summon "
+
name
+
" in face-up attack position."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Summon {} in face-up attack position"
,
cmd_idx
,
name
));
}
}
offset
+=
summonable_
.
size
();
int
spsummon_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
spsummon_
)
{
std
::
string
option
=
"c "
+
spec
;
options_
.
push_back
(
option
);
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
SpSummon
,
spec
));
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Special summon "
+
name
+
"."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Special summon {}"
,
cmd_idx
,
name
));
}
}
offset
+=
spsummon_
.
size
();
int
repos_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
repos_
)
{
std
::
string
option
=
"r "
+
spec
;
options_
.
push_back
(
option
);
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
Repo
,
spec
));
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Reposition "
+
name
+
"."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Change position of {}"
,
cmd_idx
,
name
));
}
}
offset
+=
repos_
.
size
();
int
mset_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_mset_
)
{
std
::
string
option
=
"m "
+
spec
;
options_
.
push_back
(
option
);
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
MSet
,
spec
));
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Summon "
+
name
+
" in face-down defense position."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Summon {} in face-down defense position"
,
cmd_idx
,
name
));
}
}
offset
+=
idle_mset_
.
size
();
int
set_offset
=
offset
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_set_
)
{
std
::
string
option
=
"t "
+
spec
;
options_
.
push_back
(
option
);
legal_actions_
.
push_back
(
LegalAction
::
act_spec
(
ActionAct
::
Set
,
spec
));
if
(
verbose_
)
{
const
auto
&
name
=
c_get_card
(
code
).
name_
;
pl
->
notify
(
option
+
": Set "
+
name
+
"."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Set {}"
,
cmd_idx
,
name
));
}
}
offset
+=
idle_set_
.
size
();
int
activate_offset
=
offset
;
ankerl
::
unordered_dense
::
map
<
std
::
string
,
int
>
idle_activate_count
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_activate_
)
{
idle_activate_count
[
spec
]
+=
1
;
}
ankerl
::
unordered_dense
::
map
<
std
::
string
,
int
>
activate_count
;
for
(
const
auto
&
[
code
,
spec
,
data
]
:
idle_activate_
)
{
// TODO: use effect description to indicate which effect to activate
std
::
string
option
=
"v "
+
spec
;
int
count
=
idle_activate_count
[
spec
];
activate_count
[
spec
]
++
;
if
(
count
>
1
)
{
option
.
push_back
(
'a'
+
activate_count
[
spec
]
-
1
);
for
(
const
auto
&
[
code_t
,
spec
,
desc
]
:
idle_activate_
)
{
CardCode
code
=
code_t
;
if
(
code
&
0x80000000
)
{
code
&=
0x7fffffff
;
}
options_
.
push_back
(
option
);
auto
[
code_d
,
eff_idx
]
=
unpack_desc
(
code
,
desc
);
if
(
desc
==
0
)
{
code_d
=
code
;
}
auto
la
=
LegalAction
::
activate_spec
(
eff_idx
,
spec
);
if
(
code_d
!=
0
)
{
la
.
cid_
=
c_get_card_id
(
code_d
);
}
legal_actions_
.
push_back
(
la
);
if
(
verbose_
)
{
pl
->
notify
(
option
+
": "
+
c_get_card
(
code
).
get_effect_description
(
data
));
auto
c
=
c_get_card
(
code
);
int
cmd_idx
=
legal_actions_
.
size
();
std
::
string
s
=
fmt
::
format
(
"{}: Activate {}({}) ({})"
,
cmd_idx
,
c
.
name_
,
spec
,
c
.
get_effect_description
(
code_d
,
eff_idx
));
pl
->
notify
(
s
);
}
}
if
(
to_bp_
)
{
std
::
string
cmd
=
"b"
;
options_
.
push_back
(
cmd
);
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
Battle
));
if
(
verbose_
)
{
pl
->
notify
(
cmd
+
": Enter the battle phase."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: Enter the battle phase."
,
cmd_idx
));
}
}
if
(
to_ep_
)
{
if
(
!
to_bp_
)
{
std
::
string
cmd
=
"e"
;
options_
.
push_back
(
cmd
);
legal_actions_
.
push_back
(
LegalAction
::
phase
(
ActionPhase
::
End
));
if
(
verbose_
)
{
pl
->
notify
(
cmd
+
": End phase."
);
int
cmd_idx
=
legal_actions_
.
size
();
pl
->
notify
(
fmt
::
format
(
"{}: End phase."
,
cmd_idx
));
}
}
}
...
...
@@ -4488,104 +4738,90 @@ private:
to_play_
=
player
;
callback_
=
[
this
,
spsummon_offset
,
repos_offset
,
mset_offset
,
set_offset
,
activate_offset
](
int
idx
)
{
const
auto
&
option
=
options_
[
idx
];
char
cmd
=
option
[
0
];
if
(
cmd
==
'b'
)
{
const
auto
&
action
=
legal_actions_
[
idx
];
if
(
action
.
phase_
==
ActionPhase
::
Battle
)
{
YGO_SetResponsei
(
pduel_
,
6
);
}
else
if
(
cmd
==
'e'
)
{
}
else
if
(
action
.
phase_
==
ActionPhase
::
End
)
{
YGO_SetResponsei
(
pduel_
,
7
);
}
else
{
auto
spec
=
option
.
substr
(
2
)
;
if
(
cmd
==
's'
)
{
auto
act
=
action
.
act_
;
if
(
act
==
ActionAct
::
Summon
)
{
uint32_t
idx_
=
idx
;
YGO_SetResponsei
(
pduel_
,
idx_
<<
16
);
}
else
if
(
cmd
==
'c'
)
{
}
else
if
(
act
==
ActionAct
::
SpSummon
)
{
uint32_t
idx_
=
idx
-
spsummon_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
1
);
}
else
if
(
cmd
==
'r'
)
{
}
else
if
(
act
==
ActionAct
::
Repo
)
{
uint32_t
idx_
=
idx
-
repos_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
2
);
}
else
if
(
cmd
==
'm'
)
{
}
else
if
(
act
==
ActionAct
::
MSet
)
{
uint32_t
idx_
=
idx
-
mset_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
3
);
}
else
if
(
cmd
==
't'
)
{
}
else
if
(
act
==
ActionAct
::
Set
)
{
uint32_t
idx_
=
idx
-
set_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
4
);
}
else
if
(
cmd
==
'v'
)
{
}
else
if
(
act
==
ActionAct
::
Activate
)
{
uint32_t
idx_
=
idx
-
activate_offset
;
YGO_SetResponsei
(
pduel_
,
(
idx_
<<
16
)
+
5
);
}
else
{
throw
std
::
runtime_error
(
"Invalid option: "
+
option
);
}
}
};
}
else
if
(
msg_
==
MSG_SELECT_PLACE
)
{
}
else
if
(
msg_
==
MSG_SELECT_PLACE
||
msg_
==
MSG_SELECT_DISFIELD
)
{
// TODO(1): add card informaton to select place
auto
player
=
read_u8
();
auto
count
=
read_u8
();
if
(
count
==
0
)
{
count
=
1
;
}
auto
flag
=
read_u32
();
options_
=
flag_to_usable_cardspecs
(
flag
);
if
(
verbose_
)
{
std
::
string
specs_str
=
options_
[
0
];
for
(
int
i
=
1
;
i
<
options_
.
size
();
++
i
)
{
specs_str
+=
", "
+
options_
[
i
];
}
if
(
count
==
1
)
{
players_
[
player
]
->
notify
(
"Select place for card, one of "
+
specs_str
+
"."
);
}
else
{
players_
[
player
]
->
notify
(
"Select "
+
std
::
to_string
(
count
)
+
" places for card, from "
+
specs_str
+
"."
);
}
}
to_play_
=
player
;
callback_
=
[
this
,
player
](
int
idx
)
{
int
y
=
player
+
1
;
std
::
string
spec
=
options_
[
idx
];
auto
plr
=
player
;
if
(
spec
[
0
]
==
'o'
)
{
plr
=
1
-
player
;
spec
=
spec
.
substr
(
1
);
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
resp_buf_
[
0
]
=
plr
;
resp_buf_
[
1
]
=
loc
;
resp_buf_
[
2
]
=
seq
;
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
};
}
else
if
(
msg_
==
MSG_SELECT_DISFIELD
)
{
auto
player
=
read_u8
();
auto
count
=
read_u8
();
if
(
count
==
0
)
{
count
=
1
;
if
(
count
!=
1
)
{
auto
s
=
fmt
::
format
(
"Select place count {} not implemented for {}"
,
count
,
msg_
==
MSG_SELECT_PLACE
?
"place"
:
"disfield"
);
throw
std
::
runtime_error
(
s
);
}
auto
flag
=
read_u32
();
options_
=
flag_to_usable_cardspec
s
(
flag
);
auto
places
=
flag_to_usable_place
s
(
flag
);
if
(
verbose_
)
{
std
::
string
specs_str
=
options_
[
0
];
for
(
int
i
=
1
;
i
<
options_
.
size
();
++
i
)
{
specs_str
+=
", "
+
options_
[
i
];
}
if
(
count
==
1
)
{
players_
[
player
]
->
notify
(
"Select place for card, one of "
+
specs_str
+
"."
);
}
else
{
throw
std
::
runtime_error
(
"Select disfield count "
+
std
::
to_string
(
count
)
+
" not implemented"
);
auto
place_s
=
msg_
==
MSG_SELECT_PLACE
?
"place"
:
"disfield"
;
auto
s
=
fmt
::
format
(
"Select {} for card, one of:"
,
place_s
);
players_
[
player
]
->
notify
(
s
);
}
for
(
int
i
=
0
;
i
<
places
.
size
();
++
i
)
{
legal_actions_
.
push_back
(
LegalAction
::
place
(
places
[
i
]));
if
(
verbose_
)
{
auto
s
=
fmt
::
format
(
"{}: {}"
,
i
+
1
,
action_place_to_string
(
places
[
i
]));
players_
[
player
]
->
notify
(
s
);
}
}
to_play_
=
player
;
callback_
=
[
this
,
player
](
int
idx
)
{
int
y
=
player
+
1
;
std
::
string
spec
=
options_
[
idx
];
auto
plr
=
player
;
if
(
spec
[
0
]
==
'o'
)
{
auto
place
=
legal_actions_
[
idx
].
place_
;
int
i
=
static_cast
<
int
>
(
place
);
uint8_t
plr
=
player
;
uint8_t
loc
;
uint8_t
seq
;
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
MZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
MZone7
))
{
loc
=
LOCATION_MZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
MZone1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
SZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
SZone8
))
{
loc
=
LOCATION_SZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
SZone1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
OpMZone7
))
{
plr
=
1
-
player
;
spec
=
spec
.
substr
(
1
);
loc
=
LOCATION_MZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
OpMZone1
);
}
else
if
(
i
>=
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
)
&&
i
<=
static_cast
<
int
>
(
ActionPlace
::
OpSZone8
))
{
plr
=
1
-
player
;
loc
=
LOCATION_SZONE
;
seq
=
i
-
static_cast
<
int
>
(
ActionPlace
::
OpSZone1
);
}
auto
[
loc
,
seq
,
pos
]
=
spec_to_ls
(
spec
);
resp_buf_
[
0
]
=
plr
;
resp_buf_
[
1
]
=
loc
;
resp_buf_
[
2
]
=
seq
;
...
...
@@ -4620,7 +4856,7 @@ private:
// auto spec = ls_to_spec(loc, seq, 0, controller != player);
// options_.push_back(spec);
}
// TODO: implement action
// TODO
(2)
: implement action
n_counters_
=
count
;
uint16_t
resp1
=
static_cast
<
uint16_t
>
(
std
::
min
(
counter_count
,
counters
[
0
]));
memcpy
(
resp_buf_
,
&
resp1
,
2
);
...
...
@@ -4644,19 +4880,15 @@ private:
" not implemented for announce number"
);
}
numbers
.
push_back
(
number
);
options_
.
push_back
(
std
::
string
(
1
,
'0'
+
number
));
legal_actions_
.
push_back
(
LegalAction
::
number
(
number
));
}
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
std
::
string
str
=
"Select a number, one of: ["
;
std
::
string
str
=
"Select a number, one of:"
;
pl
->
notify
(
str
);
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
str
+=
std
::
to_string
(
numbers
[
i
]);
if
(
i
<
count
-
1
)
{
str
+=
", "
;
}
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
i
+
1
,
numbers
[
i
]));
}
str
+=
"]"
;
pl
->
notify
(
str
);
}
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
...
...
@@ -4675,7 +4907,7 @@ private:
attrs
.
push_back
(
i
+
1
);
}
}
// TODO(2): implement action
if
(
count
!=
1
)
{
throw
std
::
runtime_error
(
"Announce attrib count "
+
std
::
to_string
(
count
)
+
" not implemented"
);
...
...
@@ -4686,40 +4918,28 @@ private:
pl
->
notify
(
"Select "
+
std
::
to_string
(
count
)
+
" attributes separated by spaces:"
);
for
(
int
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
pl
->
notify
(
std
::
to_string
(
attrs
[
i
])
+
": "
+
attribute_to_string
(
1
<<
(
attrs
[
i
]
-
1
)));
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
i
+
1
,
attribute_to_string
(
1
<<
(
attrs
[
i
]
-
1
))));
}
}
auto
combs
=
combinations
(
attrs
.
size
(),
count
);
for
(
const
auto
&
comb
:
combs
)
{
std
::
string
option
=
""
;
for
(
int
j
=
0
;
j
<
count
;
++
j
)
{
option
+=
std
::
to_string
(
attrs
[
comb
[
j
]]);
if
(
j
<
count
-
1
)
{
option
+=
" "
;
}
}
options_
.
push_back
(
option
);
// auto combs = combinations(attrs.size(), count);
for
(
int
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
legal_actions_
.
push_back
(
LegalAction
::
attribute
(
1
<<
(
attrs
[
i
]
-
1
)));
}
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
const
auto
&
option
=
op
tions_
[
idx
];
const
auto
&
action
=
legal_ac
tions_
[
idx
];
uint32_t
resp
=
0
;
int
i
=
0
;
while
(
i
<
option
.
size
())
{
resp
|=
1
<<
(
option
[
i
]
-
'1'
);
i
+=
2
;
}
resp
|=
action
.
attribute_
;
YGO_SetResponsei
(
pduel_
,
resp
);
};
}
else
if
(
msg_
==
MSG_SELECT_POSITION
)
{
// TODO: add card as feature
auto
player
=
read_u8
();
auto
code
=
read_u32
();
auto
valid_pos
=
read_u8
();
CardId
cid
=
c_get_card_id
(
code
);
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
...
...
@@ -4727,25 +4947,25 @@ private:
pl
->
notify
(
"Select position for "
+
card
.
name_
+
":"
);
}
std
::
vector
<
uint8_t
>
positions
;
int
i
=
1
;
for
(
auto
pos
:
{
POS_FACEUP_ATTACK
,
POS_FACEDOWN_ATTACK
,
POS_FACEUP_DEFENSE
,
POS_FACEDOWN_DEFENSE
})
{
if
(
valid_pos
&
pos
)
{
positions
.
push_back
(
pos
);
options_
.
push_back
(
std
::
to_string
(
i
));
LegalAction
la
;
la
.
cid_
=
cid
;
la
.
position_
=
pos
;
legal_actions_
.
push_back
(
la
);
int
cmd_idx
=
legal_actions_
.
size
();
if
(
verbose_
)
{
auto
pl
=
players_
[
player
];
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
i
,
position_to_string
(
pos
)));
pl
->
notify
(
fmt
::
format
(
"{}: {}"
,
cmd_idx
,
position_to_string
(
pos
)));
}
}
i
++
;
}
to_play_
=
player
;
callback_
=
[
this
](
int
idx
)
{
uint8_t
pos
=
options_
[
idx
][
0
]
-
'1'
;
YGO_SetResponsei
(
pduel_
,
1
<<
pos
);
uint8_t
pos
=
legal_actions_
[
idx
].
position_
;
YGO_SetResponsei
(
pduel_
,
pos
);
};
}
else
{
show_deck
(
0
);
...
...
@@ -4794,4 +5014,52 @@ using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
}
// namespace ygopro
/// fmt formatter for ygopro::LegalAction.
///
/// Renders an action as "{name=value, name=value}", listing only the fields
/// that differ from the sentinel values tested below. An action with none of
/// those fields set renders as "{}".
///
/// parse() is inherited from formatter<string_view>; the finished text is
/// handed to that same base formatter so that whatever string format spec was
/// parsed is actually applied to the output instead of being dropped.
template <>
struct fmt::formatter<ygopro::LegalAction> : formatter<string_view> {
  /// Builds the textual form of `action` and writes it to `ctx`.
  /// Returns the output iterator produced by the base string formatter.
  template <typename FormatContext>
  auto format(const ygopro::LegalAction &action, FormatContext &ctx) const {
    std::stringstream ss;
    ss << "{";
    if (!action.spec_.empty()) {
      ss << "spec='" << action.spec_ << "', ";
    }
    if (action.cid_ != 0) {
      ss << "cid=" << action.cid_ << ", ";
    }
    if (action.act_ != ygopro::ActionAct::None) {
      ss << "act=" << ygopro::action_act_to_string(action.act_) << ", ";
    }
    if (action.phase_ != ygopro::ActionPhase::None) {
      ss << "phase=" << ygopro::action_phase_to_string(action.phase_) << ", ";
    }
    if (action.finish_) {
      ss << "finish=true, ";
    }
    if (action.position_ != 0) {
      ss << "position=" << ygopro::position_to_string(action.position_)
         << ", ";
    }
    if (action.effect_ != -1) {
      ss << "effect=" << action.effect_ << ", ";
    }
    if (action.number_ != 0) {
      // Cast so the value is streamed as a number rather than a character.
      ss << "number=" << static_cast<int>(action.number_) << ", ";
    }
    if (action.place_ != ygopro::ActionPlace::None) {
      ss << "place=" << ygopro::action_place_to_string(action.place_) << ", ";
    }
    if (action.attribute_ != 0) {
      ss << "attribute=" << ygopro::attribute_to_string(action.attribute_)
         << ", ";
    }

    std::string s = ss.str();
    // Every emitted field ends with ", "; drop the last separator. `s` always
    // starts with "{", so back() is safe even when no field was written.
    if (s.back() == ' ') {
      s.pop_back();
      s.pop_back();
    }
    s.push_back('}');

    // Delegate to the base formatter rather than calling format_to
    // unqualified: this honours the parsed format spec and avoids ADL picking
    // up std::format_to when <format> is visible.
    return formatter<string_view>::format(string_view(s), ctx);
  }
};
#endif // YGOENV_YGOPRO_YGOPRO_H_
ygoenv/ygoenv/ygopro0/__init__.py
0 → 100644
View file @
04e61b91
# Wrap the native extension's spec/pool pair with py_env and re-export the
# resulting Python-facing classes.
from ygoenv.python.api import py_env

from .ygopro0_ygoenv import _YGOPro0EnvPool, _YGOPro0EnvSpec, init_module

# py_env's result is unpacked into exactly four names, in this order.
YGOPro0EnvSpec, YGOPro0DMEnvPool, YGOPro0GymEnvPool, YGOPro0GymnasiumEnvPool = py_env(
    _YGOPro0EnvSpec, _YGOPro0EnvPool
)

# Public names of this package (init_module is imported above but not listed).
__all__ = [
    "YGOPro0EnvSpec",
    "YGOPro0DMEnvPool",
    "YGOPro0GymEnvPool",
    "YGOPro0GymnasiumEnvPool",
]
ygoenv/ygoenv/ygopro0/registration.py
0 → 100644
View file @
04e61b91
from ygoenv.registration import register

# Register the "YGOPro-v0" task. The classes are referenced by name (strings)
# together with the import path of the package that is expected to provide them.
register(
    task_id="YGOPro-v0",
    import_path="ygoenv.ygopro0",
    spec_cls="YGOPro0EnvSpec",
    dm_cls="YGOPro0DMEnvPool",
    gym_cls="YGOPro0GymEnvPool",
    gymnasium_cls="YGOPro0GymnasiumEnvPool",
)
ygoenv/ygoenv/ygopro0/ygopro.cpp
0 → 100644
View file @
04e61b91
#include "ygoenv/ygopro0/ygopro.h"
#include "ygoenv/core/py_envpool.h"
// Python-binding alias: ygopro0's environment spec wrapped in PyEnvSpec.
using
YGOPro0EnvSpec
=
PyEnvSpec
<
ygopro0
::
YGOProEnvSpec
>
;
// Python-binding alias: ygopro0's environment pool wrapped in PyEnvPool.
using
YGOPro0EnvPool
=
PyEnvPool
<
ygopro0
::
YGOProEnvPool
>
;
// Defines the pybind11 extension module named `ygopro0_ygoenv`; `m` is the
// module handle used by the statements in the body.
PYBIND11_MODULE
(
ygopro0_ygoenv
,
m
)
{
// Hands the two aliases above to the REGISTER macro for module `m`.
// NOTE(review): there is no ';' after REGISTER(...) — presumably the macro
// supplies its own terminator; confirm against its definition. The alias
// names are passed as macro arguments, so do not rename them without
// checking how the macro uses them.
REGISTER
(
m
,
YGOPro0EnvSpec
,
YGOPro0EnvPool
)
// Exposes ygopro0::init_module to Python under the name "init_module".
m
.
def
(
"init_module"
,
&
ygopro0
::
init_module
);
}
ygoenv/ygoenv/ygopro0/ygopro.h
0 → 100644
View file @
04e61b91
This source diff could not be displayed because it is too large. You can
view the blob
instead.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment