Commit 04e61b91 authored by sbl1996@126.com

Add YGOPro-v1

parent f6139c17
# Action
## Types
- Set + card
- Reposition + card
- Special summon + card
- Summon Face-up Attack + card
- Summon Face-down Defense + card
- Attack + card
- DirectAttack + card
- Activate + card + effect
- Cancel
- Switch + phase
- SelectPosition + card + position
- AnnounceNumber + card + effect + number
- SelectPlace + card + place
- AnnounceAttrib + card + effect + attrib
## Effect
### MSG_SELECT_BATTLECMD | MSG_SELECT_IDLECMD | MSG_SELECT_CHAIN | MSG_SELECT_EFFECTYN
- desc == 0: default effect of card
- desc < LIMIT: system string
- desc > LIMIT: card + effect
### MSG_SELECT_OPTION | MSG_SELECT_YESNO
- desc == 0: error
- desc < LIMIT: system string
- desc > LIMIT: card + effect
...@@ -49,50 +49,42 @@ The card id is the index of the card code in `code_list.txt`. ...@@ -49,50 +49,42 @@ The card id is the index of the card code in `code_list.txt`.
## Legal Actions ## Legal Actions
- 0,1: spec index, uint16 -> 2 uint8 - 0: spec index
- 2: msg, discrete, 0: N/A, 1+: same as msg2str (15) - 1,2: code, uint16 -> 2 uint8
- 3: act, discrete (11) - 3: msg, discrete, 0: N/A, 1+: same as msg2str (15)
- 4: act, discrete (11)
- N/A - N/A
- t: Set - Set
- r: Reposition - Reposition
- c: Special Summon - Special Summon
- s: Summon Face-up Attack - Summon Face-up Attack
- m: Summon Face-down Defense - Summon Face-down Defense
- a: Attack - Attack
- v: Activate - DirectAttack
- v2: Activate the second effect - Activate
- v3: Activate the third effect - Cancel
- v4: Activate the fourth effect - 5: finish, discrete (2)
- 4: yes/no, discrete (3)
- N/A - N/A
- Yes - Finish
- No - 6: effect, discrete, 0: N/A
- 5: phase, discrete (4) - 7: phase, discrete (4)
- N/A - N/A
- Battle (b) - Battle (b)
- Main Phase 2 (m) - Main Phase 2 (m)
- End Phase (e) - End Phase (e)
- 6: cancel, discrete (2)
- N/A
- Cancel
- 7: finish, discrete (2)
- N/A
- Finish
- 8: position, discrete, 0: N/A, same as position2str - 8: position, discrete, 0: N/A, same as position2str
- 9: option, discrete, 0: N/A - 9: number, discrete, 0: N/A
- 10: number, discrete, 0: N/A - 10: place, discrete
- 11: place, discrete
- 0: N/A - 0: N/A
- 1-7: m - 1-7: m
- 8-15: s - 8-15: s
- 16-22: om - 16-22: om
- 23-30: os - 23-30: os
- 12: attribute, discrete, 0: N/A, same as attribute2id - 11: attribute, discrete, 0: N/A, same as attribute2id
## History Actions ## History Actions
- 0,1: card id, uint16 -> 2 uint8 - 0,1: card id, uint16 -> 2 uint8
- 2-12 same as legal actions - 2-11 same as legal actions
- 13: player, discrete, 0: me, 1: oppo - 12: turn, discrete, trunc to 3
- 14: turn, discrete, trunc to 3 - 13: phase, discrete (10)
...@@ -18,7 +18,7 @@ import flax ...@@ -18,7 +18,7 @@ import flax
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs from ygoai.rl.jax.agent import RNNAgent, ModelArgs
@dataclass @dataclass
......
...@@ -135,7 +135,7 @@ if __name__ == "__main__": ...@@ -135,7 +135,7 @@ if __name__ == "__main__":
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax import flax
from ygoai.rl.jax.agent2 import RNNAgent from ygoai.rl.jax.agent import RNNAgent
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax")) cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
...@@ -168,7 +168,6 @@ if __name__ == "__main__": ...@@ -168,7 +168,6 @@ if __name__ == "__main__":
obs, infos = envs.reset() obs, infos = envs.reset()
print(obs)
next_to_play = infos['to_play'] next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_) dones = np.zeros(num_envs, dtype=np.bool_)
......
...@@ -8,6 +8,24 @@ add_requires( ...@@ -8,6 +8,24 @@ add_requires(
"sqlitecpp 3.2.1") "sqlitecpp 3.2.1")
-- Build target for the `ygopro0` environment, compiled with xmake's
-- "python.library" rule from the C++ sources under ygoenv/ygoenv/ygopro0.
target("ygopro0_ygoenv")
    add_rules("python.library")
    add_files("ygoenv/ygoenv/ygopro0/*.cpp")
    add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "ygopro-core")
    set_languages("c++17")
    -- Release builds only: enable link-time optimization and compile for the
    -- build machine's CPU (-march=native), so the binary may not be portable.
    if is_mode("release") then
        set_policy("build.optimization.lto", true)
        add_cxxflags("-march=native")
    end
    add_includedirs("ygoenv")
    -- After a successful build, copy the produced library next to the sources
    -- in ygoenv/ygoenv/ygopro0 and report where it went.
    after_build(function (target)
        local install_target = "$(projectdir)/ygoenv/ygoenv/ygopro0"
        os.cp(target:targetfile(), install_target)
        print("Copy target to " .. install_target)
    end)
target("ygopro_ygoenv") target("ygopro_ygoenv")
add_rules("python.library") add_rules("python.library")
add_files("ygoenv/ygoenv/ygopro/*.cpp") add_files("ygoenv/ygoenv/ygopro/*.cpp")
...@@ -25,7 +43,6 @@ target("ygopro_ygoenv") ...@@ -25,7 +43,6 @@ target("ygopro_ygoenv")
print("Copy target to " .. install_target) print("Copy target to " .. install_target)
end) end)
target("edopro_ygoenv") target("edopro_ygoenv")
add_rules("python.library") add_rules("python.library")
add_files("ygoenv/ygoenv/edopro/*.cpp") add_files("ygoenv/ygoenv/edopro/*.cpp")
......
from typing import Tuple, Union, Optional from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial from functools import partial
import numpy as np
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax.linen as nn import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding
default_embed_init = nn.initializers.uniform(scale=0.001) default_embed_init = nn.initializers.uniform(scale=0.001)
...@@ -14,6 +16,13 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001) ...@@ -14,6 +16,13 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001) default_fc_init2 = nn.initializers.uniform(scale=0.001)
def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
    """Construct the transformer encoder layer selected by ``noam``.

    Despite the ``_cls`` suffix, this returns an already-constructed layer
    module (callers apply it to their inputs directly), not a class.

    Args:
        noam: If truthy, build a ``LlamaEncoderLayer`` with ``rope=False``;
            otherwise build a plain ``EncoderLayer``.
        n_heads: Passed as the first positional argument of the layer
            (the callers in this file pass the number of attention heads).
        dtype: Forwarded to the layer as ``dtype``.
        param_dtype: Forwarded to the layer as ``param_dtype``.

    Returns:
        The constructed encoder layer module.
    """
    if noam:
        return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
    else:
        return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)
class ActionEncoder(nn.Module): class ActionEncoder(nn.Module):
channels: int = 128 channels: int = 128
dtype: Optional[jnp.dtype] = None dtype: Optional[jnp.dtype] = None
...@@ -26,7 +35,6 @@ class ActionEncoder(nn.Module): ...@@ -26,7 +35,6 @@ class ActionEncoder(nn.Module):
embed = partial( embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init) embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0]) x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1]) x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2]) x_a_yesno = embed(3, c // div)(x[:, :, 2])
...@@ -38,178 +46,353 @@ class ActionEncoder(nn.Module): ...@@ -38,178 +46,353 @@ class ActionEncoder(nn.Module):
x_a_number = embed(13, c // div // 2)(x[:, :, 8]) x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9]) x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10]) x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
return jnp.concatenate([ xs = [x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib]
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib], axis=-1) return xs
class Encoder(nn.Module): class ActionEncoderV1(nn.Module):
channels: int = 128 channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
@nn.compact @nn.compact
def __call__(self, x): def __call__(self, x):
c = self.channels c = self.channels
if self.embedding_shape is None: div = 8
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False)
embed = partial( embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init) nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
fc_embed = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype) embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype) x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(10, c // div)(x[:, :, 1])
x_a_finish = embed(3, c // div // 2)(x[:, :, 2])
x_a_effect = embed(256, c // div * 2)(x[:, :, 3])
x_a_phase = embed(4, c // div // 2)(x[:, :, 4])
x_a_position = embed(9, c // div)(x[:, :, 5])
x_a_number = embed(13, c // div // 2)(x[:, :, 6])
x_a_place = embed(31, c // div)(x[:, :, 7])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 8])
xs = [x_a_msg, x_a_act, x_a_finish, x_a_effect, x_a_phase,
x_a_position, x_a_number, x_a_place, x_a_attrib]
return xs
class CardEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
version: int = 0
id_embed = embed(n_embed, embed_dim) @nn.compact
count_embed = embed(100, c // 16) def __call__(self, x_id, x):
hand_count_embed = embed(100, c // 16) c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
num_fc = MLP((c // 8,), last_lin=False, dtype=jnp.float32, param_dtype=self.param_dtype) num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32) bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals)) num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype) x1 = x[:, :, :10].astype(jnp.int32)
x_cards = x['cards_'] x2 = x[:, :, 10:].astype(jnp.float32)
x_global = x['global_']
x_actions = x['actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_cards_1 = x_cards[:, :, :12].astype(jnp.int32) x_loc = x1[:, :, 0]
x_cards_2 = x_cards[:, :, 12:].astype(jnp.float32) x_seq = x1[:, :, 1]
x_id = decode_id(x_cards_1[:, :, :2]) if self.version == 0:
x_id = id_embed(x_id) x_id = mlp(
x_id = MLP( (c, c // 4), kernel_init=default_fc_init2)(x_id)
(c, c // 4), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id) x_id = layer_norm()(x_id)
f_loc = layer_norm()(embed(9, c)(x_loc))
f_seq = layer_norm()(embed(76, c)(x_seq))
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0 c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False) c_mask = c_mask.at[:, 0].set(False)
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x_cards_1[:, :, 3]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x_cards_1[:, :, 4]) x_owner = embed(2, c // 16)(x1[:, :, 2])
x_position = embed(9, c // 16)(x_cards_1[:, :, 5]) x_position = embed(9, c // 16)(x1[:, :, 3])
x_overley = embed(2, c // 16)(x_cards_1[:, :, 6]) x_overley = embed(2, c // 16)(x1[:, :, 4])
x_attribute = embed(8, c // 16)(x_cards_1[:, :, 7]) x_attribute = embed(8, c // 16)(x1[:, :, 5])
x_race = embed(27, c // 16)(x_cards_1[:, :, 8]) x_race = embed(27, c // 16)(x1[:, :, 6])
x_level = embed(14, c // 16)(x_cards_1[:, :, 9]) x_level = embed(14, c // 16)(x1[:, :, 7])
x_counter = embed(16, c // 16)(x_cards_1[:, :, 10]) x_counter = embed(16, c // 16)(x1[:, :, 8])
x_negated = embed(3, c // 16)(x_cards_1[:, :, 11]) x_negated = embed(3, c // 16)(x1[:, :, 9])
x_atk = num_transform(x_cards_2[:, :, 0:2]) x_atk = num_transform(x2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk) x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x_cards_2[:, :, 2:4]) x_def = num_transform(x2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def) x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:]) x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
x_feat = jnp.concatenate([ if self.version == 0:
x_f = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1) x_atk, x_def, x_type], axis=-1)
x_feat = layer_norm()(x_feat) x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = jnp.concatenate([x_id, x_feat], axis=-1)
f_cards = f_cards + f_loc + f_seq f_cards = f_cards + f_loc + f_seq
else:
x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
x_id = jax.nn.swish(x_id)
f_loc = embed(9, c // 16 * 2)(x_loc)
f_seq = embed(76, c // 16 * 2)(x_seq)
x_cards = jnp.concatenate([
f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * x_id
f_cards = layer_norm()(x_cards)
return f_cards, c_mask
class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
num_heads = max(2, c // 128) @nn.compact
for _ in range(self.num_card_layers): def __call__(self, x):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards) batch_size = x.shape[0]
na_card_embed = self.param( c = self.channels
'na_card_embed', mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02, layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
(1, c), self.param_dtype) embed = partial(
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype) nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
f_cards = jnp.concatenate([f_na_card, f_cards], axis=1) fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
f_cards = layer_norm()(f_cards)
x_global_1 = x_global[:, :4].astype(jnp.float32) count_embed = embed(100, c // 16)
x_g_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2])) hand_count_embed = embed(100, c // 16)
x_g_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_2 = x_global[:, 4:8].astype(jnp.int32) num_fc = mlp((c // 8,), last_lin=False)
x_g_turn = embed(20, c // 8)(x_global_2[:, 0]) bin_points, bin_intervals = make_bin_params(n_bins=32)
x_g_phase = embed(11, c // 8)(x_global_2[:, 1]) num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x_g_if_first = embed(2, c // 8)(x_global_2[:, 2])
x_g_is_my_turn = embed(2, c // 8)(x_global_2[:, 3])
x_global_3 = x_global[:, 8:22].astype(jnp.int32) x1 = x[:, :4].astype(jnp.float32)
x_g_cs = count_embed(x_global_3).reshape((batch_size, -1)) x2 = x[:, 4:8].astype(jnp.int32)
x_g_my_hand_c = hand_count_embed(x_global_3[:, 1]) x3 = x[:, 8:22].astype(jnp.int32)
x_g_op_hand_c = hand_count_embed(x_global_3[:, 8])
x_global = jnp.concatenate([ x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 0:2]))
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn, x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 2:4]))
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1)
x_global = layer_norm()(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c)(f_global)
f_global = layer_norm()(f_global)
f_cards = f_cards + jnp.expand_dims(f_global, 1) x_turn = embed(20, c // 8)(x2[:, 0])
x_phase = embed(11, c // 8)(x2[:, 1])
x_if_first = embed(2, c // 8)(x2[:, 2])
x_is_my_turn = embed(2, c // 8)(x2[:, 3])
x_actions = x_actions.astype(jnp.int32) x_cs = count_embed(x3).reshape((batch_size, -1))
x_my_hand_c = hand_count_embed(x3[:, 1])
x_op_hand_c = hand_count_embed(x3[:, 8])
spec_index = decode_id(x_actions[..., :2]) x = jnp.concatenate([
B = jnp.arange(batch_size) x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
f_a_cards = f_cards[B[:, None], spec_index] x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
f_a_cards = f_a_cards + fc_layer(c)(layer_norm()(f_a_cards)) x = layer_norm()(x)
return x
x_a_feats = action_encoder(x_actions[..., 2:])
f_actions = f_a_cards + layer_norm()(x_a_feats)
a_mask = x_actions[:, :, 2] == 0 class Encoder(nn.Module):
a_mask = a_mask.at[:, 0].set(False) channels: int = 128
for _ in range(self.num_action_layers): num_layers: int = 2
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)( embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
f_actions, f_cards, dtype: Optional[jnp.dtype] = None
tgt_key_padding_mask=a_mask, param_dtype: jnp.dtype = jnp.float32
memory_key_padding_mask=c_mask) freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
noam: bool = False
version: int = 0
@nn.compact
def __call__(self, x):
c = self.channels
if self.embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
action_encoder = ActionEncoderCls(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id)
if self.freeze_id:
x_id = jax.lax.stop_gradient(x_id)
# Cards
f_cards, c_mask = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_id, x_cards[:, :, 2:])
g_card_embed = self.param(
'g_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
if self.card_mask:
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
else:
c_mask = None
x_h_actions = x['h_actions_'].astype(jnp.int32) num_heads = max(2, c // 128)
for _ in range(self.num_layers):
f_cards = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_cards, src_key_padding_mask=c_mask)
f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0]
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
x_global = x_global.astype(self.dtype)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
if self.version == 0:
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0 h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False) h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2]) x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = MLP( x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype, (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id)) kernel_init=default_fc_init2)(x_h_id)
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:]) f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
f_h_actions = layer_norm()(x_h_id) + layer_norm()(x_h_a_feats)
f_h_actions = PositionalEncoding()(f_h_actions) f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_action_layers): for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)( f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask) f_h_actions, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
else:
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., 1:3])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = fc_layer(c, dtype=jnp.float32)(x_h_id)
x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
x_h_a_phase = embed(12, c // 2)(x_h_actions[:, :, 13])
x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
x_h_a_feats = layer_norm()(x_h_a_feats)
x_h_a_feats = fc_layer(c, dtype=self.dtype)(x_h_a_feats)
if self.noam:
f_h_actions = LlamaEncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype,
rope=True, n_positions=64)(x_h_a_feats, src_key_padding_mask=h_mask)
else:
x_h_a_feats = PositionalEncoding()(x_h_a_feats)
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
x_h_a_feats, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(
f_actions, f_h_actions,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=h_mask)
f_actions = layer_norm()(f_actions) # Actions
x_actions = x_actions.astype(jnp.int32)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
f_s_cards_global = f_cards.mean(axis=1) if self.version == 0:
c_mask = 1 - a_mask[:, :, None].astype(f_actions.dtype) spec_index = decode_id(x_actions[..., :2])
f_s_actions_ha = (f_actions * c_mask).sum(axis=1) / c_mask.sum(axis=1) B = jnp.arange(batch_size)
f_state = jnp.concatenate([f_s_cards_global, f_s_actions_ha], axis=-1) f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
if not self.use_history:
f_g_h_actions = jnp.zeros_like(f_g_h_actions)
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
spec_index = x_actions[..., 0]
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
x_a_id = decode_id(x_actions[..., 1:3])
x_a_id = id_embed(x_a_id)
if self.freeze_id:
x_a_id = jax.lax.stop_gradient(x_a_id)
x_a_id = fc_layer(c, dtype=jnp.float32)(x_a_id)
x_a_feats = action_encoder(x_actions[..., 3:])
x_a_feats.append(x_a_id)
x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
x_a_feats = layer_norm()(x_a_feats)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
f_actions = jax.nn.silu(f_a_cards) * x_a_feats
f_actions = fc_layer(c, dtype=self.dtype)(f_actions)
f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False)
f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
if self.use_history:
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid return f_actions, f_state, a_mask, valid
...@@ -219,54 +402,199 @@ class Actor(nn.Module): ...@@ -219,54 +402,199 @@ class Actor(nn.Module):
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
@nn.compact @nn.compact
def __call__(self, f_actions, mask): def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01)) mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
f_state = mlp((c,), use_bias=True)(f_state)
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class FiLMActor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
noam: bool = False
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
num_heads = max(2, c // 128) num_heads = max(2, c // 128)
f_actions = EncoderLayer( f_actions = get_encoder_layer_cls(
num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask) self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
logits = mlp((c // 4, 1), use_bias=True)(f_actions) f_actions, mask, a_s, a_b, o_s, o_b)
logits = logits[..., 0]
logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
big_neg = jnp.finfo(logits.dtype).min big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits) logits = jnp.where(mask, big_neg, logits)
return logits return logits
class Critic(nn.Module): class Critic(nn.Module):
channels: int = 128 channels: Sequence[int] = (128, 128, 128)
dtype: Optional[jnp.dtype] = None dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
@nn.compact @nn.compact
def __call__(self, f_state): def __call__(self, f_state):
c = self.channels f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(1.0)) mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = MLP((c // 2, 1), use_bias=True)(f_state) x = mlp(self.channels, last_lin=False)(f_state)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
return x return x
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
    """Run one recurrent step, optionally routing between two stored states.

    Args:
        rnn_layer: Callable ``(rstate, f_state) -> (new_rstate, output)``.
        rstate: The recurrent state. When ``main`` is not ``None`` this must
            be a pair ``(rstate1, rstate2)`` of identically-structured
            pytrees; otherwise it is a single state pytree.
        f_state: Input features for this step.
        done: Optional per-row boolean array; rows where it is true have
            their state reset to zero after the step. ``None`` skips the reset.
        main: Optional per-row boolean array. Rows where it is true read from
            and write back to ``rstate1``; the other rows use ``rstate2``.
            ``None`` disables the routing.
            NOTE(review): presumably "main" marks the main player's turn in a
            two-player rollout -- confirm against the caller.

    Returns:
        ``(rstate, f_state)``: the updated state (a pair again when ``main``
        is given) and the output of ``rnn_layer``.
    """
    # ``mask[:, None]`` broadcasts the per-row mask over the feature axis;
    # this assumes every state leaf is (batch, features) -- TODO confirm.
    if main is not None:
        rstate1, rstate2 = rstate
        # Pick, per row, the state belonging to the side that acts now.
        rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
    rstate, f_state = rnn_layer(rstate, f_state)
    if main is not None:
        # Write the new state back only into the slot it was read from;
        # the other slot keeps its previous value.
        rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
        rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
        rstate = rstate1, rstate2
    if done is not None:
        # Zero every state leaf (both slots when routed) for finished rows.
        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
    return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
    """Scan ``rnn_layer`` over a sequence of steps with two-sided state handling.

    ``f_state``, ``done`` and ``switch_or_main`` are scanned together by
    ``nn.scan`` (the caller in this file reshapes them to
    ``(num_steps, batch_size, ...)`` first); ``rstate`` is the scan carry.

    Args:
        rnn_layer: The recurrent cell module; its parameters are shared
            across all steps (``variable_broadcast='params'``).
        rstate: Initial carry. In both modes it is unpacked as a pair:
            ``(rstate, init_rstate2)`` when ``switch`` is true, and
            ``(rstate1, rstate2)`` as expected by ``rnn_step_by_main``
            otherwise.
        f_state: Per-step input features.
        done: Per-step, per-row boolean array; the state is zeroed where true.
        switch_or_main: Per-step, per-row boolean array. Interpreted as a
            "switch" mask when ``switch`` is true, and as the ``main`` mask of
            ``rnn_step_by_main`` otherwise.
        switch: Selects which of the two step functions is scanned.

    Returns:
        ``(rstate, f_state)``: the final carry (still a pair) and the stacked
        per-step outputs of the cell.
    """
    if switch:
        def body_fn(cell, carry, x, done, switch):
            # ``init_rstate2`` is carried through unchanged; it is only the
            # value the running state is replaced with on a switch.
            rstate, init_rstate2 = carry
            rstate, y = cell(rstate, x)
            # Reset finished rows first, then overwrite switched rows.
            rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
            rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
            return (rstate, init_rstate2), y
    else:
        def body_fn(cell, carry, x, done, main):
            return rnn_step_by_main(cell, carry, x, done, main)
    scan = nn.scan(
        body_fn, variable_broadcast='params',
        split_rngs={'params': False})
    rstate, f_state = scan(rnn_layer, rstate, f_state, done, switch_or_main)
    return rstate, f_state
@dataclass
class ModelArgs:
    """Configuration options for building the agent model."""

    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    rnn_channels: int = 512
    """the number of channels for the RNN in the agent"""
    use_history: bool = True
    """whether to use history actions as input for agent"""
    card_mask: bool = False
    """whether to mask the padding card as ignored in the transformer"""
    rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
    """the type of RNN to use, None for no RNN"""
    film: bool = False
    """whether to use FiLM for the actor"""
    noam: bool = False
    """whether to use Noam architecture for the transformer layer"""
    version: int = 0
    """the version of the environment and the agent"""
class RNNAgent(nn.Module):
num_layers: int = 2
num_channels: int = 128
rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
switch: bool = True
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
version: int = 0
@nn.compact @nn.compact
def __call__(self, x): def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.num_channels
encoder = Encoder( encoder = Encoder(
channels=self.channels, channels=c,
num_card_layers=self.num_card_layers, num_layers=self.num_layers,
num_action_layers=self.num_action_layers,
embedding_shape=self.embedding_shape, embedding_shape=self.embedding_shape,
dtype=self.dtype, dtype=self.dtype,
param_dtype=self.param_dtype, param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
use_history=self.use_history,
card_mask=self.card_mask,
noam=self.noam,
version=self.version,
) )
actor = Actor(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
critic = Critic(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
f_actions, f_state, mask, valid = encoder(x) f_actions, f_state, mask, valid = encoder(x)
logits = actor(f_actions, mask)
value = critic(f_state) if self.rnn_type in ['lstm', 'none']:
return logits, value, valid rnn_layer = nn.OptimizedLSTMCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type == 'gru':
rnn_layer = nn.GRUCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type is None:
rnn_layer = None
if rnn_layer is None:
f_state_r = f_state
elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else:
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1
if done is not None:
assert switch_or_main is not None
else:
assert not multi_step
if multi_step:
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
else:
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r)
return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']:
return (
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels)),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
else:
return None
\ No newline at end of file
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
# Small-scale uniform initializers used as defaults by the encoders in this
# module: one for embedding tables and two for dense/MLP kernels. All three
# currently share the same scale.
default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
    """Build one transformer encoder layer instance.

    Despite the ``_cls`` suffix this returns a constructed module, not a
    class: a ``LlamaEncoderLayer`` (with ``rope=False``) when ``noam`` is
    truthy, otherwise a plain ``EncoderLayer``.
    """
    shared = dict(dtype=dtype, param_dtype=param_dtype)
    if not noam:
        return EncoderLayer(n_heads, **shared)
    return LlamaEncoderLayer(n_heads, rope=False, **shared)
class ActionEncoder(nn.Module):
    """Embed the discrete feature columns of an action tensor.

    ``__call__`` takes an integer array indexed as ``x[:, :, k]`` for
    ``k`` in 0..10 and returns a list with one embedding per column, in
    column order. The first five columns use ``channels // 8`` features,
    the remaining six use half of that.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        wide = self.channels // 8
        narrow = wide // 2
        # (vocabulary size, embedding width) for each column. The order must
        # stay fixed: Flax auto-names the tables by construction order.
        columns = (
            (30, wide),    # 0: msg
            (13, wide),    # 1: act
            (3, wide),     # 2: yes/no
            (4, wide),     # 3: phase
            (3, wide),     # 4: cancel
            (3, narrow),   # 5: finish
            (9, narrow),   # 6: position
            (6, narrow),   # 7: option
            (13, narrow),  # 8: number
            (31, narrow),  # 9: place
            (10, narrow),  # 10: attrib
        )
        make_embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)
        return [
            make_embed(vocab, width)(x[:, :, col])
            for col, (vocab, width) in enumerate(columns)
        ]
class CardEncoder(nn.Module):
    """Encode per-card features into ``channels``-wide vectors.

    ``__call__(x_id, x)`` takes the already-embedded card ids and the raw
    card feature array. The first 10 feature columns are read as integer
    categories and the rest as floats. It returns ``(f_cards, c_mask)``,
    where ``c_mask`` is True for cards whose location column is 0, except
    that the first card slot is always left unmasked.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x_id, x):
        c = self.channels
        small = c // 16
        make_mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        make_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        make_embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
        make_proj = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)

        # One MLP instance is shared by the ATK and DEF encodings below.
        num_fc = make_mlp((c // 8,), last_lin=False)
        bin_points, bin_intervals = make_bin_params(n_bins=32)

        def encode_number(raw):
            return num_fc(bytes_to_bin(raw, bin_points, bin_intervals))

        categorical = x[:, :, :10].astype(jnp.int32)
        numeric = x[:, :, 10:].astype(jnp.float32)

        # Card identity: project the id embedding down to a quarter of c.
        id_feat = make_mlp(
            (c, c // 4), kernel_init=default_fc_init2)(x_id)
        id_feat = make_norm()(id_feat)

        # Location / sequence act as additive, full-width embeddings.
        location = categorical[:, :, 0]
        c_mask = (location == 0).at[:, 0].set(False)
        f_loc = make_norm()(make_embed(9, c)(location))
        f_seq = make_norm()(make_embed(76, c)(categorical[:, :, 1]))

        # Columns 2..9: owner, position, overlay, attribute, race, level,
        # counter, negated. Built in column order to keep Flax auto-names.
        vocab_sizes = (2, 9, 2, 8, 27, 14, 16, 3)
        cat_feats = [
            make_embed(vocab, small)(categorical[:, :, col])
            for col, vocab in enumerate(vocab_sizes, start=2)
        ]

        atk_feat = encode_number(numeric[:, :, 0:2])
        atk_feat = make_proj(small, kernel_init=default_fc_init1)(atk_feat)
        def_feat = encode_number(numeric[:, :, 2:4])
        def_feat = make_proj(small, kernel_init=default_fc_init1)(def_feat)
        type_feat = make_proj(small * 2, kernel_init=default_fc_init2)(numeric[:, :, 4:])

        attr_feat = jnp.concatenate(
            [*cat_feats, atk_feat, def_feat, type_feat], axis=-1)
        attr_feat = make_norm()(attr_feat)

        f_cards = jnp.concatenate([id_feat, attr_feat], axis=-1)
        f_cards = f_cards + f_loc + f_seq
        return f_cards, c_mask
class GlobalEncoder(nn.Module):
    """Encode the global (non-card) feature vector of each sample.

    ``__call__(x)`` slices ``x`` into columns 0-3 (read as floats, used for
    the two life-point encodings), columns 4-7 (integer categories) and
    columns 8-21 (integer counts), embeds each group, concatenates the
    results and applies a LayerNorm.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        n = x.shape[0]
        c = self.channels
        make_mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        make_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        make_embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
        make_proj = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)

        # Constructed up-front so they keep the first two Embed auto-names.
        count_embed = make_embed(100, c // 16)
        hand_count_embed = make_embed(100, c // 16)

        num_fc = make_mlp((c // 8,), last_lin=False)
        bin_points, bin_intervals = make_bin_params(n_bins=32)

        def encode_number(raw):
            return num_fc(bytes_to_bin(raw, bin_points, bin_intervals))

        lp_raw = x[:, :4].astype(jnp.float32)
        status = x[:, 4:8].astype(jnp.int32)
        counts = x[:, 8:22].astype(jnp.int32)

        lp_feat = make_proj(c // 4, kernel_init=default_fc_init2)(encode_number(lp_raw[:, 0:2]))
        oppo_lp_feat = make_proj(c // 4, kernel_init=default_fc_init2)(encode_number(lp_raw[:, 2:4]))

        # Status columns in order: turn, phase, if_first, is_my_turn.
        status_feats = [
            make_embed(vocab, c // 8)(status[:, col])
            for col, vocab in enumerate((20, 11, 2, 2))
        ]

        counts_feat = count_embed(counts).reshape((n, -1))
        # Count columns 1 and 8 additionally get a dedicated embedding.
        my_hand_feat = hand_count_embed(counts[:, 1])
        op_hand_feat = hand_count_embed(counts[:, 8])

        merged = jnp.concatenate(
            [lp_feat, oppo_lp_feat, *status_feats,
             counts_feat, my_hand_feat, op_hand_feat], axis=-1)
        return make_norm()(merged)
class Encoder(nn.Module):
    """Encode an observation dict into per-action features and a state vector.

    ``__call__(x)`` reads ``x['cards_']``, ``x['global_']``, ``x['actions_']``
    and ``x['h_actions_']`` and returns ``(f_actions, f_state, a_mask, valid)``.

    NOTE(review): the indentation of this block was reconstructed. The
    LayerNorm after the card transformer loop is assumed to run once, after
    the loop. Confirm against the upstream file.
    """
    channels: int = 128
    num_layers: int = 2
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    freeze_id: bool = False
    use_history: bool = True
    card_mask: bool = False
    noam: bool = False

    @nn.compact
    def __call__(self, x):
        c = self.channels
        # embedding_shape may be None (defaults), an int (number of ids) or
        # an explicit (number of ids, embedding width) pair.
        if self.embedding_shape is None:
            n_embed, embed_dim = 999, 1024
        elif isinstance(self.embedding_shape, int):
            n_embed, embed_dim = self.embedding_shape, 1024
        else:
            n_embed, embed_dim = self.embedding_shape
        n_embed = 1 + n_embed # 1 (index 0) for unknown

        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(
            nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

        # A single id table and a single ActionEncoder are shared between the
        # history-action branch and the legal-action branch below.
        id_embed = embed(n_embed, embed_dim)
        action_encoder = ActionEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)

        x_cards = x['cards_']
        x_global = x['global_']
        x_actions = x['actions_']
        x_h_actions = x['h_actions_']

        batch_size = x_cards.shape[0]

        # A sample is flagged valid when its last global feature equals 0.
        valid = x_global[:, -1] == 0

        # The first two card columns hold the card id, decoded by decode_id.
        x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
        x_id = id_embed(x_id)
        if self.freeze_id:
            x_id = jax.lax.stop_gradient(x_id)

        # Cards
        f_cards, c_mask = CardEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
        # Learned summary token prepended to the card sequence; its output
        # (position 0) is read back as f_g_card after the transformer.
        g_card_embed = self.param(
            'g_card_embed',
            lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
            (1, c), self.param_dtype)
        f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
        f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
        if self.card_mask:
            # Extend the padding mask with an unmasked entry for the token.
            c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
        else:
            c_mask = None

        num_heads = max(2, c // 128)
        for _ in range(self.num_layers):
            f_cards = get_encoder_layer_cls(
                self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                f_cards, src_key_padding_mask=c_mask)
        f_cards = layer_norm(dtype=self.dtype)(f_cards)
        f_g_card = f_cards[:, 0]

        # Global
        x_global = GlobalEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
        x_global = x_global.astype(self.dtype)
        # Residual MLP followed by a projection down to c channels.
        f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
        f_global = fc_layer(c, dtype=self.dtype)(f_global)
        f_global = layer_norm(dtype=self.dtype)(f_global)

        # History actions
        x_h_actions = x_h_actions.astype(jnp.int32)
        h_mask = x_h_actions[:, :, 2] == 0  # msg == 0
        # Position 0 is never masked.
        h_mask = h_mask.at[:, 0].set(False)

        x_h_id = decode_id(x_h_actions[..., :2])
        x_h_id = id_embed(x_h_id)
        if self.freeze_id:
            x_h_id = jax.lax.stop_gradient(x_h_id)
        x_h_id = MLP(
            (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
            kernel_init=default_fc_init2)(x_h_id)

        # Columns 2..12 go through the shared ActionEncoder; columns 13 and
        # 14 (player, turn) get their own embeddings.
        x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
        x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
        x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
        x_h_a_feats = jnp.concatenate([
            *x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
        f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))

        f_h_actions = PositionalEncoding()(f_h_actions)
        # The history branch always uses EncoderLayer, regardless of `noam`.
        for _ in range(self.num_layers):
            f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                f_h_actions, src_key_padding_mask=h_mask)
        f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])

        # Actions
        x_actions = x_actions.astype(jnp.int32)

        # Replace slot 0 of the card features (the summary token) with a
        # learned embedding before gathering card features per action.
        na_card_embed = self.param(
            'na_card_embed',
            lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
            (1, c), self.param_dtype)
        f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
        f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)

        # Gather, for each action, the card feature its spec index points at.
        spec_index = decode_id(x_actions[..., :2])
        B = jnp.arange(batch_size)
        f_a_cards = f_cards[B[:, None], spec_index]
        f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)

        x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
        x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
        f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
        f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
        f_actions = layer_norm(dtype=self.dtype)(f_actions)

        # Actions whose column 2 is 0 are masked out; the first action is
        # always kept, so the mean below never divides by zero.
        a_mask = x_actions[:, :, 2] == 0
        a_mask = a_mask.at[:, 0].set(False)

        # Masked mean over the action axis.
        a_mask_ = (1 - a_mask.astype(f_actions.dtype))
        f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
        f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)

        # State
        if not self.use_history:
            # Keep the layout of f_state but drop the history information.
            f_g_h_actions = jnp.zeros_like(f_g_h_actions)
        f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
        f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
        f_state = layer_norm(dtype=self.dtype)(f_state)
        return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
    """Score each candidate action by a dot product with the state.

    ``__call__(f_state, f_actions, mask)`` projects the state through a
    one-layer MLP of width ``channels``, contracts it against every action
    feature (``'bc,bnc->bn'``) and returns the logits, with positions where
    ``mask`` is True set to the most negative finite value of the logit dtype.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        state = f_state.astype(self.dtype)
        actions = f_actions.astype(self.dtype)
        # The projection always computes in float32, whatever self.dtype is.
        query = MLP(
            (self.channels,), use_bias=True,
            dtype=jnp.float32, param_dtype=self.param_dtype,
            last_kernel_init=nn.initializers.orthogonal(0.01))(state)
        scores = jnp.einsum('bc,bnc->bn', query, actions)
        lowest = jnp.finfo(scores.dtype).min
        return jnp.where(mask, lowest, scores)
class FiLMActor(nn.Module):
    """Actor head that conditions an encoder layer on the state vector.

    ``__call__(f_state, f_actions, mask)`` maps the state to ``4 * channels``
    values, splits them into four equal chunks and passes those, together
    with the action features and mask, to one encoder layer. A final Dense
    layer produces one logit per action; masked positions receive the most
    negative finite value of the logit dtype.
    """
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    noam: bool = False

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        state = f_state.astype(self.dtype)
        actions = f_actions.astype(self.dtype)
        c = self.channels

        film = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(state)
        # Insert a length-1 action axis so the chunks broadcast over actions.
        # presumably (a_s, a_b) / (o_s, o_b) are scale/bias pairs - TODO
        # confirm against the encoder layer's signature.
        a_s, a_b, o_s, o_b = jnp.split(film[:, None, :], 4, axis=-1)

        heads = max(2, c // 128)
        layer = get_encoder_layer_cls(
            self.noam, heads, dtype=self.dtype, param_dtype=self.param_dtype)
        actions = layer(actions, mask, a_s, a_b, o_s, o_b)

        to_logit = nn.Dense(
            1, dtype=jnp.float32, param_dtype=self.param_dtype,
            kernel_init=nn.initializers.orthogonal(0.01))
        scores = to_logit(actions)[:, :, 0]
        lowest = jnp.finfo(scores.dtype).min
        return jnp.where(mask, lowest, scores)
class Critic(nn.Module):
    """Value head: an MLP over the state followed by a 1-unit Dense layer.

    ``channels`` lists the hidden widths of the MLP. The final Dense layer
    always computes in float32 and returns an array whose last axis has
    size 1.
    """
    channels: Sequence[int] = (128, 128, 128)
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state):
        hidden = MLP(
            self.channels, last_lin=False,
            dtype=self.dtype, param_dtype=self.param_dtype)(f_state.astype(self.dtype))
        value_head = nn.Dense(
            1, dtype=jnp.float32, param_dtype=self.param_dtype,
            kernel_init=nn.initializers.orthogonal(1.0))
        return value_head(hidden)
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
    """Advance the recurrent state by one step.

    Args:
        rnn_layer: callable ``(state, x) -> (state, y)``.
        rstate: a single state when ``main`` is None, otherwise a pair
            ``(state1, state2)`` of states with identical tree structure.
        f_state: input for this step.
        done: optional per-row flags; rows where set have their returned
            state(s) zeroed.
        main: optional per-row flags. Where set, the step reads and updates
            ``state1``; elsewhere it reads and updates ``state2``. The
            state that was not selected is carried over unchanged.

    Returns:
        ``(new_rstate, y)`` with ``new_rstate`` shaped like ``rstate``.
    """
    if main is None:
        new_state, out = rnn_layer(rstate, f_state)
    else:
        pick = main[:, None]
        first, second = rstate
        # Per row, feed the cell with whichever of the two states is active.
        active = jax.tree.map(lambda a, b: jnp.where(pick, a, b), first, second)
        active, out = rnn_layer(active, f_state)
        # Write the result back only into the slot it was read from.
        new_state = (
            jax.tree.map(lambda new, old: jnp.where(pick, new, old), active, first),
            jax.tree.map(lambda new, old: jnp.where(pick, old, new), active, second),
        )
    if done is not None:
        reset = done[:, None]
        new_state = jax.tree.map(lambda s: jnp.where(reset, 0, s), new_state)
    return new_state, out
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
    """Unroll ``rnn_layer`` over a sequence of steps with ``nn.scan``.

    Args:
        rnn_layer: recurrent cell called as ``cell(state, x) -> (state, y)``.
        rstate: initial scan carry. With ``switch=True`` it is unpacked as
            ``(rstate, init_rstate2)``; otherwise it is handed unchanged to
            ``rnn_step_by_main``.
        f_state: per-step inputs, scanned together with ``done`` and
            ``switch_or_main``.
        done: per-step flags; the state is zeroed where the flag is set.
        switch_or_main: per-step flags, read as "switch" flags when
            ``switch=True`` and as "main" flags otherwise.
        switch: static Python bool selecting the step rule.

    Returns:
        ``(final_carry, outputs)`` as produced by the scan.
    """
    def body_fn(cell, carry, x, done, flag):
        # `switch` is a plain Python bool, so this branch is resolved once
        # while tracing rather than per step.
        if not switch:
            return rnn_step_by_main(cell, carry, x, done, flag)
        state, replacement = carry
        state, out = cell(state, x)
        # Zero the state of finished rows first ...
        state = jax.tree.map(lambda s: jnp.where(done[:, None], 0, s), state)
        # ... then swap in the second carry element where the flag is set.
        state = jax.tree.map(
            lambda r, s: jnp.where(flag[:, None], r, s), replacement, state)
        return (state, replacement), out

    # The same parameters are used at every step; no per-step 'params' RNG.
    scanned = nn.scan(
        body_fn, variable_broadcast='params',
        split_rngs={'params': False})
    final_carry, outputs = scanned(rnn_layer, rstate, f_state, done, switch_or_main)
    return final_carry, outputs
@dataclass
class ModelArgs:
    """Hyper-parameters that configure the agent network."""
    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    rnn_channels: int = 512
    """the number of channels for the RNN in the agent"""
    use_history: bool = True
    """whether to use history actions as input for agent"""
    card_mask: bool = False
    """whether to mask the padding card as ignored in the transformer"""
    # NOTE(review): the string 'none' and the value None are distinct options;
    # RNNAgent in this file tiles the state for 'none' and bypasses the RNN
    # entirely for None.
    rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
    """the type of RNN to use, None for no RNN"""
    film: bool = False
    """whether to use FiLM for the actor"""
    noam: bool = False
    """whether to use Noam architecture for the transformer layer"""
class RNNAgent(nn.Module):
    """Actor-critic agent: encoder, optional recurrent core, actor and critic.

    ``rnn_type`` selects the core:

    * ``'lstm'`` / ``'gru'``: the encoder state is passed through the cell.
    * ``'none'``: no recurrent update; the encoder state is repeated
      ``rnn_channels // num_channels`` times along the last axis and
      ``rstate`` is returned untouched.
    * ``None``: the encoder state is fed to the heads as is.

    Any other value raises ``ValueError``.
    """
    num_layers: int = 2
    num_channels: int = 128
    rnn_channels: int = 512
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    switch: bool = True
    freeze_id: bool = False
    use_history: bool = True
    card_mask: bool = False
    rnn_type: Optional[str] = 'lstm'
    film: bool = False
    noam: bool = False

    @nn.compact
    def __call__(self, x, rstate, done=None, switch_or_main=None):
        """Run the agent on a batch of observations.

        Args:
            x: observation dict consumed by ``Encoder``.
            rstate: recurrent state (or pair of states, see
                ``rnn_step_by_main`` / ``rnn_forward_2p``).
            done: optional per-sample flags; required when ``x`` holds more
                than one step per environment.
            switch_or_main: optional per-sample flags; required whenever
                ``done`` is given.

        Returns:
            ``(rstate, logits, value, valid)``.

        Raises:
            ValueError: if ``rnn_type`` is not one of the supported values.
        """
        c = self.num_channels
        encoder = Encoder(
            channels=c,
            num_layers=self.num_layers,
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            freeze_id=self.freeze_id,
            use_history=self.use_history,
            card_mask=self.card_mask,
            noam=self.noam,
        )

        f_actions, f_state, mask, valid = encoder(x)

        # For 'none' the LSTM cell is constructed but never called.
        if self.rnn_type in ['lstm', 'none']:
            rnn_layer = nn.OptimizedLSTMCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type == 'gru':
            rnn_layer = nn.GRUCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type is None:
            rnn_layer = None
        else:
            # Without this branch `rnn_layer` would be unbound below and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown rnn_type: {self.rnn_type!r}, "
                "expected 'lstm', 'gru', 'none' or None")

        if rnn_layer is None:
            f_state_r = f_state
        elif self.rnn_type == 'none':
            # Widen the state to rnn_channels by repetition so the heads see
            # the same input width as with a real recurrent core.
            f_state_r = jnp.concatenate([f_state] * (self.rnn_channels // c), axis=-1)
        else:
            # The batch size comes from the recurrent state; any extra
            # leading factor in f_state is interpreted as time steps.
            batch_size = jax.tree.leaves(rstate)[0].shape[0]
            num_steps = f_state.shape[0] // batch_size
            multi_step = num_steps > 1

            if done is not None:
                assert switch_or_main is not None
            else:
                assert not multi_step

            if multi_step:
                f_state_r, done, switch_or_main = jax.tree.map(
                    lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
                rstate, f_state_r = rnn_forward_2p(
                    rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
                # Flatten (num_steps, batch_size, ...) back to one leading axis.
                f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
            else:
                rstate, f_state_r = rnn_step_by_main(
                    rnn_layer, rstate, f_state, done, switch_or_main)

        if self.film:
            actor = FiLMActor(
                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
        else:
            actor = Actor(
                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
        critic = Critic(
            channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)

        logits = actor(f_state_r, f_actions, mask)
        value = critic(f_state_r)
        return rstate, logits, value, valid

    def init_rnn_state(self, batch_size):
        """Return a zero recurrent state for ``batch_size`` rows.

        A pair of ``(batch_size, rnn_channels)`` arrays for ``'lstm'`` and
        ``'none'``, a single such array for ``'gru'``, and ``None`` otherwise.
        """
        if self.rnn_type in ['lstm', 'none']:
            return (
                np.zeros((batch_size, self.rnn_channels)),
                np.zeros((batch_size, self.rnn_channels)),
            )
        elif self.rnn_type == 'gru':
            return np.zeros((batch_size, self.rnn_channels))
        else:
            return None
\ No newline at end of file
...@@ -41,7 +41,12 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False): ...@@ -41,7 +41,12 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
raise FileNotFoundError(f"Token deck not found: {token_deck}") raise FileNotFoundError(f"Token deck not found: {token_deck}")
decks["_tokens"] = str(token_deck) decks["_tokens"] = str(token_deck)
if 'YGOPro' in env_id: if 'YGOPro' in env_id:
if env_id == 'YGOPro-v1':
from ygoenv.ygopro import init_module from ygoenv.ygopro import init_module
elif env_id == 'YGOPro-v0':
from ygoenv.ygopro0 import init_module
else:
raise ValueError(f"Unknown YGOPro environment: {env_id}")
elif 'EDOPro' in env_id: elif 'EDOPro' in env_id:
from ygoenv.edopro import init_module from ygoenv.edopro import init_module
init_module(str(db_path), code_list_file, decks) init_module(str(db_path), code_list_file, decks)
......
...@@ -18,13 +18,16 @@ try: ...@@ -18,13 +18,16 @@ try:
except ImportError: except ImportError:
pass pass
try:
import ygoenv.ygopro0.registration # noqa: F401
except ImportError:
pass
try: try:
import ygoenv.edopro.registration # noqa: F401 import ygoenv.edopro.registration # noqa: F401
except ImportError: except ImportError:
pass pass
try: try:
import ygoenv.dummy.registration # noqa: F401 import ygoenv.dummy.registration # noqa: F401
except ImportError: except ImportError:
......
from ygoenv.registration import register from ygoenv.registration import register
register( register(
task_id="YGOPro-v0", task_id="YGOPro-v1",
import_path="ygoenv.ygopro", import_path="ygoenv.ygopro",
spec_cls="YGOProEnvSpec", spec_cls="YGOProEnvSpec",
dm_cls="YGOProDMEnvPool", dm_cls="YGOProDMEnvPool",
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <ankerl/unordered_dense.h> #include <ankerl/unordered_dense.h>
#include <unordered_set> #include <unordered_set>
#include "BS_thread_pool.h" #include "ygoenv/core/BS_thread_pool.h"
#include "ygoenv/core/async_envpool.h" #include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h" #include "ygoenv/core/env.h"
...@@ -305,13 +305,85 @@ static std::string msg_to_string(int msg) { ...@@ -305,13 +305,85 @@ static std::string msg_to_string(int msg) {
} }
// system string // system string
static const ankerl::unordered_dense::map<int, std::string> system_strings = { static const std::map<int, std::string> system_strings = {
// announce type
{1050, "Monster"},
{1051, "Spell"},
{1052, "Trap"},
{1054, "Normal"},
{1055, "Effect"},
{1056, "Fusion"},
{1057, "Ritual"},
{1058, "Trap Monsters"},
{1059, "Spirit"},
{1060, "Union"},
{1061, "Gemini"},
{1062, "Tuner"},
{1063, "Synchro"},
{1064, "Token"},
{1066, "Quick-Play"},
{1067, "Continuous"},
{1068, "Equip"},
{1069, "Field"},
{1070, "Counter"},
{1071, "Flip"},
{1072, "Toon"},
{1073, "Xyz"},
{1074, "Pendulum"},
{1075, "Special Summon"},
{1076, "Link"},
{1080, "(N/A)"},
{1081, "Extra Monster Zone"},
// announce type end
// actions
{1150, "Activate"},
{1151, "Normal Summon"},
{1152, "Special Summon"},
{1153, "Set"},
{1154, "Flip Summon"},
{1155, "To Defense"},
{1156, "To Attack"},
{1157, "Attack"},
{1158, "View"},
{1159, "S/T Set"},
{1160, "Put in Pendulum Zone"},
{1161, "Do Effect"},
{1162, "Reset Effect"},
{1163, "Pendulum Summon"},
{1164, "Synchro Summon"},
{1165, "Xyz Summon"},
{1166, "Link Summon"},
{1167, "Tribute Summon"},
{1168, "Ritual Summon"},
{1169, "Fusion Summon"},
{1190, "Add to hand"},
{1191, "Send to GY"},
{1192, "Banish"},
{1193, "Return to Deck"},
// actions end
{1, "Normal Summon"},
{30, "Replay rules apply. Continue this attack?"}, {30, "Replay rules apply. Continue this attack?"},
{31, "Attack directly with this monster?"}, {31, "Attack directly with this monster?"},
{80, "Start Step of the Battle Phase."},
{81, "During the End Phase."},
{90, "Conduct this Normal Summon without Tributing?"},
{91, "Use additional Summon?"},
{92, "Tribute your opponent's monster?"},
{93, "Continue selecting Materials?"},
{94, "Activate this card's effect now?"},
{95, "Use the effect of [%ls]?"},
{96, "Use the effect of [%ls] to avoid destruction?"}, {96, "Use the effect of [%ls] to avoid destruction?"},
{97, "Place [%ls] to a Spell & Trap Zone?"},
{98, "Tribute a monster(s) your opponent controls?"},
{200, "From [%ls], activate [%ls]?"},
{203, "Chain another card or effect?"},
{210, "Continue selecting?"},
{218, "Pay LP by Effect of [%ls], instead?"},
{219, "Detach Xyz material by Effect of [%ls], instead?"},
{220, "Remove Counter(s) by Effect of [%ls], instead?"},
{221, "On [%ls], Activate Trigger Effect of [%ls]?"},
{222, "Activate Trigger Effect?"},
{221, "On [%ls], Activate Trigger Effect of [%ls]?"}, {221, "On [%ls], Activate Trigger Effect of [%ls]?"},
{1190, "Add to hand"},
{1192, "Banish"},
{1621, "Attack Negated"}, {1621, "Attack Negated"},
{1622, "[%ls] Missed timing"} {1622, "[%ls] Missed timing"}
}; };
...@@ -321,7 +393,9 @@ static std::string get_system_string(int desc) { ...@@ -321,7 +393,9 @@ static std::string get_system_string(int desc) {
if (it != system_strings.end()) { if (it != system_strings.end()) {
return it->second; return it->second;
} }
return "system string " + std::to_string(desc); throw std::runtime_error(
fmt::format("Cannot find system string: {}", desc));
// return "system string " + std::to_string(desc);
} }
static std::string ltrim(std::string s) { static std::string ltrim(std::string s) {
...@@ -331,24 +405,6 @@ static std::string ltrim(std::string s) { ...@@ -331,24 +405,6 @@ static std::string ltrim(std::string s) {
return s; return s;
} }
inline std::vector<std::string> flag_to_usable_cardspecs(uint32_t flag,
bool reverse = false) {
std::string zone_names[4] = {"m", "s", "om", "os"};
std::vector<std::string> specs;
for (int j = 0; j < 4; j++) {
uint32_t value = (flag >> (j * 8)) & 0xff;
for (int i = 0; i < 8; i++) {
bool avail = (value & (1 << i)) == 0;
if (reverse) {
avail = !avail;
}
if (avail) {
specs.push_back(zone_names[j] + std::to_string(i + 1));
}
}
}
return specs;
}
inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos) { inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos) {
std::string spec; std::string spec;
...@@ -402,7 +458,8 @@ spec_to_ls(const std::string spec) { ...@@ -402,7 +458,8 @@ spec_to_ls(const std::string spec) {
loc = LOCATION_DECK; loc = LOCATION_DECK;
offset = 0; offset = 0;
} else { } else {
throw std::runtime_error("Invalid location"); std::string s = fmt::format("Invalid spec {}", spec);
throw std::runtime_error(s);
} }
int end = offset; int end = offset;
while (end < spec.size() && std::isdigit(spec[end])) { while (end < spec.size() && std::isdigit(spec[end])) {
...@@ -415,33 +472,19 @@ spec_to_ls(const std::string spec) { ...@@ -415,33 +472,19 @@ spec_to_ls(const std::string spec) {
return {loc, seq, pos}; return {loc, seq, pos};
} }
inline uint32_t ls_to_spec_code(uint8_t loc, uint8_t seq, uint8_t pos,
bool opponent) {
uint32_t c = opponent ? 1 : 0;
c |= (loc << 8);
c |= (seq << 16);
c |= (pos << 24);
return c;
}
inline uint32_t spec_to_code(const std::string &spec) { inline std::tuple<uint8_t, uint8_t, uint8_t, uint8_t>
spec_to_ls(uint8_t player, const std::string spec) {
uint8_t controller = player;
int offset = 0; int offset = 0;
bool opponent = false;
if (spec[0] == 'o') { if (spec[0] == 'o') {
opponent = true; controller = 1 - player;
offset++; offset++;
} }
auto [loc, seq, pos] = spec_to_ls(spec.substr(offset)); auto [loc, seq, pos] = spec_to_ls(spec.substr(offset));
return ls_to_spec_code(loc, seq, pos, opponent); return {controller, loc, seq, pos};
} }
inline std::string code_to_spec(uint32_t spec_code) {
uint8_t loc = (spec_code >> 8) & 0xff;
uint8_t seq = (spec_code >> 16) & 0xff;
uint8_t pos = (spec_code >> 24) & 0xff;
bool opponent = (spec_code & 0xff) == 1;
return ls_to_spec(loc, seq, pos, opponent);
}
static std::tuple<std::vector<uint32>, std::vector<uint32>, std::vector<uint32>> read_decks(const std::string &fp) { static std::tuple<std::vector<uint32>, std::vector<uint32>, std::vector<uint32>> read_decks(const std::string &fp) {
std::ifstream file(fp); std::ifstream file(fp);
...@@ -567,6 +610,11 @@ inline std::string name(decltype(x_map)::key_type x) { \ ...@@ -567,6 +610,11 @@ inline std::string name(decltype(x_map)::key_type x) { \
return "unknown"; \ return "unknown"; \
} }
static const ankerl::unordered_dense::map<int, uint8_t> system_string2id =
make_ids(system_strings, 16);
DEFINE_X_TO_ID_FUN(system_string_to_id, system_string2id)
static const std::map<uint8_t, std::string> location2str = { static const std::map<uint8_t, std::string> location2str = {
{LOCATION_DECK, "Deck"}, {LOCATION_DECK, "Deck"},
{LOCATION_HAND, "Hand"}, {LOCATION_HAND, "Hand"},
...@@ -722,29 +770,152 @@ static const ankerl::unordered_dense::map<int, uint8_t> msg2id = ...@@ -722,29 +770,152 @@ static const ankerl::unordered_dense::map<int, uint8_t> msg2id =
DEFINE_X_TO_ID_FUN(msg_to_id, msg2id) DEFINE_X_TO_ID_FUN(msg_to_id, msg2id)
static const ankerl::unordered_dense::map<char, uint8_t> cmd_act2id = enum class ActionAct {
make_ids({'t', 'r', 'c', 's', 'm', 'a', 'v'}, 1); None,
DEFINE_X_TO_ID_FUN(cmd_act_to_id, cmd_act2id) Set,
Repo,
SpSummon,
Summon,
MSet,
Attack,
DirectAttack,
Activate,
Cancel,
};
// Returns the display name of an ActionAct value; values that match no
// enumerator yield "Unknown".
inline std::string action_act_to_string(ActionAct act) {
  const char *label = "Unknown";
  switch (act) {
  case ActionAct::None:
    label = "None";
    break;
  case ActionAct::Set:
    label = "Set";
    break;
  case ActionAct::Repo:
    label = "Repo";
    break;
  case ActionAct::SpSummon:
    label = "SpSummon";
    break;
  case ActionAct::Summon:
    label = "Summon";
    break;
  case ActionAct::MSet:
    label = "MSet";
    break;
  case ActionAct::Attack:
    label = "Attack";
    break;
  case ActionAct::DirectAttack:
    label = "DirectAttack";
    break;
  case ActionAct::Activate:
    label = "Activate";
    break;
  case ActionAct::Cancel:
    label = "Cancel";
    break;
  default:
    break;
  }
  return label;
}
static const ankerl::unordered_dense::map<char, uint8_t> cmd_phase2id = enum class ActionPhase {
make_ids(std::vector<char>({'b', 'm', 'e'}), 1); None,
DEFINE_X_TO_ID_FUN(cmd_phase_to_id, cmd_phase2id) Battle,
Main2,
End,
};
// Returns the display name of an ActionPhase value; values that match no
// enumerator yield "Unknown".
inline std::string action_phase_to_string(ActionPhase phase) {
  const char *label = "Unknown";
  switch (phase) {
  case ActionPhase::None:
    label = "None";
    break;
  case ActionPhase::Battle:
    label = "Battle";
    break;
  case ActionPhase::Main2:
    label = "Main2";
    break;
  case ActionPhase::End:
    label = "End";
    break;
  default:
    break;
  }
  return label;
}
static const ankerl::unordered_dense::map<char, uint8_t> cmd_yesno2id = enum class ActionPlace {
make_ids(std::vector<char>({'y', 'n'}), 1); None,
DEFINE_X_TO_ID_FUN(cmd_yesno_to_id, cmd_yesno2id) MZone1,
MZone2,
MZone3,
MZone4,
MZone5,
MZone6,
MZone7,
SZone1,
SZone2,
SZone3,
SZone4,
SZone5,
SZone6,
SZone7,
SZone8,
OpMZone1,
OpMZone2,
OpMZone3,
OpMZone4,
OpMZone5,
OpMZone6,
OpMZone7,
OpSZone1,
OpSZone2,
OpSZone3,
OpSZone4,
OpSZone5,
OpSZone6,
OpSZone7,
OpSZone8,
};
static const ankerl::unordered_dense::map<std::string, uint8_t> cmd_place2id = inline std::vector<ActionPlace> flag_to_usable_places(
make_ids(std::vector<std::string>( uint32_t flag, bool reverse = false) {
{"m1", "m2", "m3", "m4", "m5", "m6", "m7", "s1", std::vector<ActionPlace> places;
"s2", "s3", "s4", "s5", "s6", "s7", "s8", "om1", for (int j = 0; j < 4; j++) {
"om2", "om3", "om4", "om5", "om6", "om7", "os1", "os2", uint32_t value = (flag >> (j * 8)) & 0xff;
"os3", "os4", "os5", "os6", "os7", "os8"}), for (int i = 0; i < 8; i++) {
1); bool avail = (value & (1 << i)) == 0;
DEFINE_X_TO_ID_FUN(cmd_place_to_id, cmd_place2id) if (reverse) {
avail = !avail;
}
if (avail) {
ActionPlace place;
if (j == 0) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::MZone1));
} else if (j == 1) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::SZone1));
} else if (j == 2) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::OpMZone1));
} else if (j == 3) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::OpSZone1));
}
places.push_back(place);
}
}
}
return places;
}
// Formats an ActionPlace as a zone label: "m<k>" / "s<k>" for the
// MZone / SZone ranges, "om<k>" / "os<k>" for the OpMZone / OpSZone ranges
// (k is 1-based within its range), "None" for value 0 and "Unknown" for
// anything outside those ranges.
inline std::string action_place_to_string(ActionPlace place) {
  const int value = static_cast<int>(place);
  // True when `value` lies in the closed enumerator range [first, last].
  const auto within = [value](ActionPlace first, ActionPlace last) {
    return value >= static_cast<int>(first) && value <= static_cast<int>(last);
  };
  // 1-based position of `value` inside the range that starts at `first`.
  const auto ordinal = [value](ActionPlace first) {
    return value - static_cast<int>(first) + 1;
  };

  if (value == 0) {
    return "None";
  }
  if (within(ActionPlace::MZone1, ActionPlace::MZone7)) {
    return fmt::format("m{}", ordinal(ActionPlace::MZone1));
  }
  if (within(ActionPlace::SZone1, ActionPlace::SZone8)) {
    return fmt::format("s{}", ordinal(ActionPlace::SZone1));
  }
  if (within(ActionPlace::OpMZone1, ActionPlace::OpMZone7)) {
    return fmt::format("om{}", ordinal(ActionPlace::OpMZone1));
  }
  if (within(ActionPlace::OpSZone1, ActionPlace::OpSZone8)) {
    return fmt::format("os{}", ordinal(ActionPlace::OpSZone1));
  }
  return "Unknown";
}
inline std::pair<uint8_t, uint8_t> float_transform(int x) { inline std::pair<uint8_t, uint8_t> float_transform(int x) {
...@@ -807,6 +978,89 @@ using PlayerId = uint8_t; ...@@ -807,6 +978,89 @@ using PlayerId = uint8_t;
using CardCode = uint32_t; using CardCode = uint32_t;
using CardId = uint16_t; using CardId = uint16_t;
// NOTE(review): presumably the "LIMIT" threshold from the action docs
// (desc < LIMIT: system string, desc > LIMIT: card + effect); its use is not
// visible in this part of the file -- confirm against the message handlers.
const int DESCRIPTION_LIMIT = 10000;
// Base value of card-specific effect indices: get_effect_description subtracts
// it to index into the card's strings, and _set_obs_action_effect treats any
// effect >= this value as a card effect.
const int CARD_EFFECT_OFFSET = 10010;
// One option a player may choose in response to the current message.
//
// Only the fields relevant to the kind of choice are set; everything else keeps
// its "not applicable" default. Instances are normally built through the static
// factory functions below.
class LegalAction {
public:
  // Card location spec string the action refers to; empty when no card is involved.
  std::string spec_;
  // Kind of act (e.g. Activate, Cancel); None when not applicable.
  ActionAct act_ = ActionAct::None;
  // Phase to switch to; None when not applicable.
  ActionPhase phase_ = ActionPhase::None;
  // True for the "finish selecting" option.
  bool finish_ = false;
  // Chosen position value; 0 when not applicable.
  uint8_t position_ = 0;
  // Effect index; -1 means no effect is attached to the action.
  int effect_ = -1;
  // Announced number; 0 when not applicable.
  uint8_t number_ = 0;
  // Selected place; None when not applicable.
  ActionPlace place_ = ActionPlace::None;
  // Announced attribute value; 0 when not applicable.
  uint8_t attribute_ = 0;
  // Index of the referenced card in the card observation; 0 until resolved.
  int spec_index_ = 0;
  // Card id of the referenced card; 0 when unknown.
  CardId cid_ = 0;
  // Message this action answers; 0 until assigned.
  int msg_ = 0;

  // Action that only selects the card identified by `spec`.
  static LegalAction from_spec(const std::string &spec) {
    LegalAction action;
    action.spec_ = spec;
    return action;
  }

  // Action performing `act` on the card identified by `spec`.
  static LegalAction act_spec(ActionAct act, const std::string &spec) {
    LegalAction action = from_spec(spec);
    action.act_ = act;
    return action;
  }

  // "Finish" option.
  static LegalAction finish() {
    LegalAction action;
    action.finish_ = true;
    return action;
  }

  // "Cancel" option.
  static LegalAction cancel() {
    LegalAction action;
    action.act_ = ActionAct::Cancel;
    return action;
  }

  // Activation of effect `effect_idx` of the card identified by `spec`.
  static LegalAction activate_spec(int effect_idx, const std::string &spec) {
    LegalAction action = act_spec(ActionAct::Activate, spec);
    action.effect_ = effect_idx;
    return action;
  }

  // Switch to `phase`.
  static LegalAction phase(ActionPhase phase) {
    LegalAction action;
    action.phase_ = phase;
    return action;
  }

  // Announce `number`.
  static LegalAction number(uint8_t number) {
    LegalAction action;
    action.number_ = number;
    return action;
  }

  // Select `place`.
  static LegalAction place(ActionPlace place) {
    LegalAction action;
    action.place_ = place;
    return action;
  }

  // Announce `attribute`; the value is narrowed to uint8_t when stored.
  static LegalAction attribute(int attribute) {
    LegalAction action;
    action.attribute_ = static_cast<uint8_t>(attribute);
    return action;
  }
};
// Per-spec record produced while encoding the card observation
// (_set_obs_cards) and later copied onto each LegalAction that names that spec.
class SpecInfo {
public:
// 1-based row of the card in the card feature array (stored after the row
// offset has been incremented).
uint16_t index;
// Card id of the card at this spec; left at 0 when the card is hidden.
CardId cid;
};
class Card { class Card {
friend class YGOProEnv; friend class YGOProEnv;
...@@ -874,42 +1128,23 @@ public: ...@@ -874,42 +1128,23 @@ public:
return get_spec(player != controler_); return get_spec(player != controler_);
} }
// Encode this card's location, sequence and position as a spec code from the
// point of view of `player`; the last argument tells ls_to_spec_code whether
// the card is controlled by someone other than `player`.
uint32_t get_spec_code(PlayerId player) const {
  const bool opponent = player != controler_;
  return ls_to_spec_code(location_, sequence_, position_, opponent);
}
std::string get_position() const { return position_to_string(position_); } std::string get_position() const { return position_to_string(position_); }
std::string get_effect_description(uint32_t desc, std::string get_effect_description(CardCode code, int effect_idx) const {
bool existing = false) const { if (code == 0) {
std::string s; return get_system_string(effect_idx);
bool e = false;
auto code = code_;
if (desc > 10000) {
code = desc >> 4;
}
uint32_t offset = desc - code_ * 16;
bool in_range = (offset >= 0) && (offset < strings_.size());
std::string str = "";
if (in_range) {
str = ltrim(strings_[offset]);
}
if (in_range || desc == 0) {
if ((desc == 0) || str.empty()) {
s = "Activate " + name_ + ".";
} else {
s = name_ + " (" + str + ")";
e = true;
} }
} else { if (effect_idx == 0) {
s = get_system_string(desc); return "default";
if (!s.empty()) {
e = true;
} }
effect_idx -= CARD_EFFECT_OFFSET;
if (effect_idx < 0) {
throw std::runtime_error(
fmt::format("Invalid effect index: {}", effect_idx));
} }
if (existing && !e) { auto s = strings_[effect_idx];
s = ""; if (s.empty()) {
return "effect " + std::to_string(effect_idx);
} }
return s; return s;
} }
...@@ -1222,7 +1457,7 @@ public: ...@@ -1222,7 +1457,7 @@ public:
const int &init_lp() const { return init_lp_; } const int &init_lp() const { return init_lp_; }
virtual int think(const std::vector<std::string> &options) = 0; virtual int think(const std::vector<LegalAction> &actions) = 0;
}; };
class GreedyAI : public Player { class GreedyAI : public Player {
...@@ -1232,7 +1467,7 @@ public: ...@@ -1232,7 +1467,7 @@ public:
bool verbose = false) bool verbose = false)
: Player(nickname, init_lp, duel_player, verbose) {} : Player(nickname, init_lp, duel_player, verbose) {}
int think(const std::vector<std::string> &options) override { return 0; } int think(const std::vector<LegalAction> &actions) override { return 0; }
}; };
class RandomAI : public Player { class RandomAI : public Player {
...@@ -1246,8 +1481,8 @@ public: ...@@ -1246,8 +1481,8 @@ public:
: Player(nickname, init_lp, duel_player, verbose), gen_(seed), : Player(nickname, init_lp, duel_player, verbose), gen_(seed),
dist_(0, max_options - 1) {} dist_(0, max_options - 1) {}
int think(const std::vector<std::string> &options) override { int think(const std::vector<LegalAction> &actions) override {
return dist_(gen_) % options.size(); return dist_(gen_) % actions.size();
} }
}; };
...@@ -1258,17 +1493,17 @@ public: ...@@ -1258,17 +1493,17 @@ public:
bool verbose = false) bool verbose = false)
: Player(nickname, init_lp, duel_player, verbose) {} : Player(nickname, init_lp, duel_player, verbose) {}
int think(const std::vector<std::string> &options) override { int think(const std::vector<LegalAction> &actions) override {
while (true) { while (true) {
std::string input = getline(); std::string input = getline();
if (input == "quit") { if (input == "quit") {
exit(0); exit(0);
} }
auto it = std::find(options.begin(), options.end(), input); int idx = std::stoi(input) - 1;
if (it != options.end()) { if (idx >= 0 && idx < actions.size()) {
return std::distance(options.begin(), it); return idx;
} else { } else {
fmt::println("{} Choose from {}", duel_player_, options); fmt::println("{} Choose from {} actions", duel_player_, actions.size());
} }
} }
} }
...@@ -1286,7 +1521,7 @@ public: ...@@ -1286,7 +1521,7 @@ public:
} }
template <typename Config> template <typename Config>
static decltype(auto) StateSpec(const Config &conf) { static decltype(auto) StateSpec(const Config &conf) {
int n_action_feats = 13; int n_action_feats = 12;
return MakeDict( return MakeDict(
"obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})), "obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"obs:global_"_.Bind(Spec<uint8_t>({23})), "obs:global_"_.Bind(Spec<uint8_t>({23})),
...@@ -1393,7 +1628,7 @@ protected: ...@@ -1393,7 +1628,7 @@ protected:
int turn_count_; int turn_count_;
int msg_; int msg_;
std::vector<std::string> options_; std::vector<LegalAction> legal_actions_;
PlayerId to_play_; PlayerId to_play_;
std::function<void(int)> callback_; std::function<void(int)> callback_;
...@@ -1423,9 +1658,10 @@ protected: ...@@ -1423,9 +1658,10 @@ protected:
const int n_history_actions_; const int n_history_actions_;
// circular buffer for history actions // circular buffer for history actions
TArray<uint8_t> history_actions_; TArray<uint8_t> history_actions_1_;
int ha_p_ = 0; TArray<uint8_t> history_actions_2_;
std::vector<CardId> h_card_ids_; int ha_p_1_ = 0;
int ha_p_2_ = 0;
std::unordered_set<std::string> revealed_; std::unordered_set<std::string> revealed_;
...@@ -1487,8 +1723,9 @@ public: ...@@ -1487,8 +1723,9 @@ public:
int max_options = spec.config["max_options"_]; int max_options = spec.config["max_options"_];
int n_action_feats = spec.state_spec["obs:actions_"_].shape[1]; int n_action_feats = spec.state_spec["obs:actions_"_].shape[1];
h_card_ids_.resize(max_options); history_actions_1_ = TArray<uint8_t>(Array(
history_actions_ = TArray<uint8_t>(Array( ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
history_actions_2_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2}))); ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
} }
...@@ -1560,8 +1797,10 @@ public: ...@@ -1560,8 +1797,10 @@ public:
turn_count_ = 0; turn_count_ = 0;
ms_idx_ = -1; ms_idx_ = -1;
history_actions_.Zero(); history_actions_1_.Zero();
ha_p_ = 0; history_actions_2_.Zero();
ha_p_1_ = 0;
ha_p_2_ = 0;
clock_t _start = clock(); clock_t _start = clock();
...@@ -1720,7 +1959,7 @@ public: ...@@ -1720,7 +1959,7 @@ public:
if (ms_mode_ == 0) { if (ms_mode_ == 0) {
for (int j = 0; j < ms_specs_.size(); ++j) { for (int j = 0; j < ms_specs_.size(); ++j) {
const auto &spec = ms_specs_[j]; const auto &spec = ms_specs_[j];
options_.push_back(spec); legal_actions_.push_back(LegalAction::from_spec(spec));
} }
} else { } else {
ms_combs_ = combs; ms_combs_ = combs;
...@@ -1729,22 +1968,23 @@ public: ...@@ -1729,22 +1968,23 @@ public:
} }
void handle_multi_select() { void handle_multi_select() {
options_ = {}; legal_actions_.clear();
if (ms_mode_ == 0) { if (ms_mode_ == 0) {
for (int j = 0; j < ms_specs_.size(); ++j) { for (int j = 0; j < ms_specs_.size(); ++j) {
if (ms_spec2idx_.find(ms_specs_[j]) != ms_spec2idx_.end()) { if (ms_spec2idx_.find(ms_specs_[j]) != ms_spec2idx_.end()) {
options_.push_back(ms_specs_[j]); legal_actions_.push_back(
LegalAction::from_spec(ms_specs_[j]));
} }
} }
if (ms_idx_ == ms_max_ - 1) { if (ms_idx_ == ms_max_ - 1) {
if (ms_idx_ >= ms_min_) { if (ms_idx_ >= ms_min_) {
options_.push_back("f"); legal_actions_.push_back(LegalAction::finish());
} }
callback_ = [this](int idx) { callback_ = [this](int idx) {
_callback_multi_select(idx, true); _callback_multi_select(idx, true);
}; };
} else if (ms_idx_ >= ms_min_) { } else if (ms_idx_ >= ms_min_) {
options_.push_back("f"); legal_actions_.push_back(LegalAction::finish());
callback_ = [this](int idx) { callback_ = [this](int idx) {
_callback_multi_select(idx, false); _callback_multi_select(idx, false);
}; };
...@@ -1766,7 +2006,7 @@ public: ...@@ -1766,7 +2006,7 @@ public:
if (it != ms_spec2idx_.end()) { if (it != ms_spec2idx_.end()) {
return it->second; return it->second;
} }
// TODO: find the root cause // TODO(2): find the root cause
// print ms_spec2idx // print ms_spec2idx
show_deck(0); show_deck(0);
show_deck(1); show_deck(1);
...@@ -1783,11 +2023,15 @@ public: ...@@ -1783,11 +2023,15 @@ public:
} }
void _callback_multi_select_2(int idx) { void _callback_multi_select_2(int idx) {
const auto &option = options_[idx]; const auto &action = legal_actions_[idx];
idx = get_ms_spec_idx(option); idx = get_ms_spec_idx(action.spec_);
if (idx == -1) { if (idx == -1) {
// TODO: find the root cause // TODO(2): find the root cause
fmt::println("options: {}, idx: {}, option: {}", options_, idx, option); std::vector<std::string> specs;
for (const auto &la : legal_actions_) {
specs.push_back(la.spec_);
}
fmt::println("specs: {}, idx: {}, spec: {}", specs, idx, action.spec_);
throw std::runtime_error("Spec not found"); throw std::runtime_error("Spec not found");
} }
ms_r_idxs_.push_back(idx); ms_r_idxs_.push_back(idx);
...@@ -1814,7 +2058,7 @@ public: ...@@ -1814,7 +2058,7 @@ public:
} }
for (auto &i : comb) { for (auto &i : comb) {
const auto &spec = ms_specs_[i]; const auto &spec = ms_specs_[i];
options_.push_back(spec); legal_actions_.push_back(LegalAction::from_spec(spec));
} }
} }
...@@ -1831,17 +2075,21 @@ public: ...@@ -1831,17 +2075,21 @@ public:
} }
void _callback_multi_select(int idx, bool finish) { void _callback_multi_select(int idx, bool finish) {
const auto &option = options_[idx]; const auto &action = legal_actions_[idx];
// fmt::println("Select card: {}, finish: {}", option, finish); // fmt::println("Select card: {}, finish: {}", option, finish);
if (option == "f") { if (action.finish_) {
finish = true; finish = true;
} else { } else {
idx = get_ms_spec_idx(option); idx = get_ms_spec_idx(action.spec_);
if (idx != -1) { if (idx != -1) {
ms_r_idxs_.push_back(idx); ms_r_idxs_.push_back(idx);
} else { } else {
// TODO: find the root cause // TODO(2): find the root cause
fmt::println("options: {}, idx: {}, option: {}", options_, idx, option); std::vector<std::string> specs;
for (const auto &la : legal_actions_) {
specs.push_back(la.spec_);
}
fmt::println("specs: {}, idx: {}, spec: {}", specs, idx, action.spec_);
ms_idx_ = -1; ms_idx_ = -1;
resp_buf_[0] = ms_min_; resp_buf_[0] = ms_min_;
for (int i = 0; i < ms_min_; ++i) { for (int i = 0; i < ms_min_; ++i) {
...@@ -1860,27 +2108,27 @@ public: ...@@ -1860,27 +2108,27 @@ public:
YGO_SetResponseb(pduel_, resp_buf_); YGO_SetResponseb(pduel_, resp_buf_);
} else { } else {
ms_idx_++; ms_idx_++;
ms_spec2idx_.erase(option); ms_spec2idx_.erase(action.spec_);
} }
} }
void update_h_card_ids(PlayerId player, int idx) { void update_history_actions(PlayerId player, const LegalAction& action) {
h_card_ids_[idx] = parse_card_id(options_[idx], player); if (action.act_ == ActionAct::Cancel) {
}
void update_history_actions(PlayerId player, int idx) {
if ((msg_ == MSG_SELECT_CHAIN) & (options_[idx][0] == 'c')) {
return; return;
} }
ha_p_--; auto& ha_p = player == 0 ? ha_p_1_ : ha_p_2_;
if (ha_p_ < 0) { auto& history_actions = player == 0 ? history_actions_1_ : history_actions_2_;
ha_p_ = n_history_actions_ - 1; ha_p--;
if (ha_p < 0) {
ha_p = n_history_actions_ - 1;
} }
history_actions_[ha_p_].Zero(); history_actions[ha_p].Zero();
_set_obs_action(history_actions_, ha_p_, msg_, options_[idx], {}, _set_obs_action(history_actions, ha_p, action);
h_card_ids_[idx]); // Spec index not available in history actions
history_actions_[ha_p_](13) = static_cast<uint8_t>(player); history_actions[ha_p](0) = 0;
history_actions_[ha_p_](14) = static_cast<uint8_t>(turn_count_); // history_actions[ha_p](12) = static_cast<uint8_t>(player);
history_actions[ha_p](12) = static_cast<uint8_t>(turn_count_);
history_actions[ha_p](13) = static_cast<uint8_t>(phase_to_id(current_phase_));
} }
void show_deck(const std::vector<CardCode> &deck, const std::string &prefix) const { void show_deck(const std::vector<CardCode> &deck, const std::string &prefix) const {
...@@ -1910,18 +2158,18 @@ public: ...@@ -1910,18 +2158,18 @@ public:
} }
void show_history_actions(PlayerId player) const { void show_history_actions(PlayerId player) const {
const auto &ha = history_actions_; const auto &ha = player == 0 ? history_actions_1_ : history_actions_2_;
// print card ids of history actions // print card ids of history actions
for (int i = 0; i < n_history_actions_; ++i) { for (int i = 0; i < n_history_actions_; ++i) {
fmt::print("history {}\n", i); fmt::print("history {}\n", i);
uint8_t msg_id = uint8_t(ha(i, 2)); uint8_t msg_id = uint8_t(ha(i, 3));
int msg = _msgs[msg_id - 1]; int msg = _msgs[msg_id - 1];
fmt::print("msg: {},", msg_to_string(msg)); fmt::print("msg: {},", msg_to_string(msg));
uint8_t v1 = ha(i, 0); uint8_t v1 = ha(i, 1);
uint8_t v2 = ha(i, 1); uint8_t v2 = ha(i, 2);
CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2); CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2);
fmt::print(" {};", card_id); fmt::print(" {};", card_id);
for (int j = 3; j < ha.Shape()[1]; j++) { for (int j = 4; j < ha.Shape()[1]; j++) {
fmt::print(" {}", uint8_t(ha(i, j))); fmt::print(" {}", uint8_t(ha(i, j)));
} }
fmt::print("\n"); fmt::print("\n");
...@@ -1933,7 +2181,7 @@ public: ...@@ -1933,7 +2181,7 @@ public:
int idx = action["action"_]; int idx = action["action"_];
callback_(idx); callback_(idx);
update_history_actions(to_play_, idx); update_history_actions(to_play_, legal_actions_[idx]);
PlayerId player = to_play_; PlayerId player = to_play_;
...@@ -2012,10 +2260,10 @@ public: ...@@ -2012,10 +2260,10 @@ public:
} }
private: private:
using SpecIndex = ankerl::unordered_dense::map<std::string, uint16_t>; using SpecInfos = ankerl::unordered_dense::map<std::string, SpecInfo>;
std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) { std::tuple<SpecInfos, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
SpecIndex spec2index; SpecInfos spec_infos;
std::vector<int> loc_n_cards; std::vector<int> loc_n_cards;
int offset = 0; int offset = 0;
for (auto pi = 0; pi < 2; pi++) { for (auto pi = 0; pi < 2; pi++) {
...@@ -2054,18 +2302,23 @@ private: ...@@ -2054,18 +2302,23 @@ private:
hide = false; hide = false;
} }
} }
CardId card_id = 0;
if (!hide) {
card_id = c_get_card_id(c.code_);
}
_set_obs_card_(f_cards, offset, c, hide); _set_obs_card_(f_cards, offset, c, hide);
offset++; offset++;
spec2index[spec] = static_cast<uint16_t>(offset);
spec_infos[spec] = {static_cast<uint16_t>(offset), card_id};
} }
} }
} }
} }
return {spec2index, loc_n_cards}; return {spec_infos, loc_n_cards};
} }
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c, void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide) { bool hide, CardId card_id = 0) {
// check offset exceeds max_cards // check offset exceeds max_cards
uint8_t location = c.location_; uint8_t location = c.location_;
bool overlay = location & LOCATION_OVERLAY; bool overlay = location & LOCATION_OVERLAY;
...@@ -2077,7 +2330,6 @@ private: ...@@ -2077,7 +2330,6 @@ private:
} }
if (!hide) { if (!hide) {
auto card_id = c_get_card_id(c.code_);
f_cards(offset, 0) = static_cast<uint8_t>(card_id >> 8); f_cards(offset, 0) = static_cast<uint8_t>(card_id >> 8);
f_cards(offset, 1) = static_cast<uint8_t>(card_id & 0xff); f_cards(offset, 1) = static_cast<uint8_t>(card_id & 0xff);
} }
...@@ -2148,17 +2400,10 @@ private: ...@@ -2148,17 +2400,10 @@ private:
} }
} }
void _set_obs_action_spec(TArray<uint8_t> &feat, int i, const SpecInfo& find_spec_info(SpecInfos &spec_infos, const std::string &spec) {
const std::string &spec, auto it = spec_infos.find(spec);
const SpecIndex &spec2index, if (it == spec_infos.end()) {
CardId card_id = 0) { // TODO(2): find the root cause
uint16_t idx;
if (spec2index.empty()) {
idx = card_id;
} else {
auto it = spec2index.find(spec);
if (it == spec2index.end()) {
// TODO: find the root cause
// print spec2index // print spec2index
show_deck(0); show_deck(0);
show_deck(1); show_deck(1);
...@@ -2166,135 +2411,111 @@ private: ...@@ -2166,135 +2411,111 @@ private:
show_turn(); show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_); fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec); fmt::println("Spec: {}, Spec2index:", spec);
for (auto &[k, v] : spec2index) { for (auto &[k, v] : spec_infos) {
fmt::print("{}: {}, ", k, v); fmt::print("{}: {} {}, ", k, v.index, v.cid);
} }
fmt::print("\n"); fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec); // throw std::runtime_error("Spec not found: " + spec);
idx = 1; spec_infos[spec] = {1, 1};
} else { return spec_infos[spec];
idx = it->second;
}
}
feat(i, 0) = static_cast<uint8_t>(idx >> 8);
feat(i, 1) = static_cast<uint8_t>(idx & 0xff);
} }
return it->second;
void _set_obs_action_msg(TArray<uint8_t> &feat, int i, int msg) {
feat(i, 2) = msg_to_id(msg);
} }
void _set_obs_action_act(TArray<uint8_t> &feat, int i, char act, void _set_obs_action_spec(
uint8_t act_offset = 0) { TArray<uint8_t> &feat, int i, int idx) {
feat(i, 3) = cmd_act_to_id(act) + act_offset; feat(i, 0) = static_cast<uint8_t>(idx);
} }
void _set_obs_action_yesno(TArray<uint8_t> &feat, int i, char yesno) { void _set_obs_action_card_id(
feat(i, 4) = cmd_yesno_to_id(yesno); TArray<uint8_t> &feat, int i, CardId cid) {
feat(i, 1) = static_cast<uint8_t>(cid >> 8);
feat(i, 2) = static_cast<uint8_t>(cid & 0xff);
} }
void _set_obs_action_phase(TArray<uint8_t> &feat, int i, char phase) { void _set_obs_action_msg(TArray<uint8_t> &feat, int i, int msg) {
feat(i, 5) = cmd_phase_to_id(phase); feat(i, 3) = msg_to_id(msg);
} }
void _set_obs_action_cancel(TArray<uint8_t> &feat, int i) { void _set_obs_action_act(TArray<uint8_t> &feat, int i, ActionAct act) {
feat(i, 6) = 1; feat(i, 4) = static_cast<uint8_t>(act);
} }
void _set_obs_action_finish(TArray<uint8_t> &feat, int i) { void _set_obs_action_finish(TArray<uint8_t> &feat, int i) {
feat(i, 7) = 1; feat(i, 5) = 1;
}
// Write the encoded effect of an action into column 6 of row `i` of `feat`.
//
// Encoding:
//   0    : None (effect == -1)
//   1    : default effect (effect == 0)
//   2-15 : card effect (effect >= CARD_EFFECT_OFFSET, rebased to start at 2)
//   16+  : system string (mapped through system_string_to_id)
// NOTE(review): the result is narrowed to uint8_t without a range check.
void _set_obs_action_effect(TArray<uint8_t> &feat, int i, int effect) {
  int encoded;
  if (effect == -1) {
    encoded = 0;
  } else if (effect == 0) {
    encoded = 1;
  } else if (effect >= CARD_EFFECT_OFFSET) {
    encoded = effect - CARD_EFFECT_OFFSET + 2;
  } else {
    encoded = system_string_to_id(effect);
  }
  feat(i, 6) = static_cast<uint8_t>(encoded);
}
void _set_obs_action_position(TArray<uint8_t> &feat, int i, char position) { void _set_obs_action_phase(TArray<uint8_t> &feat, int i, ActionPhase phase){
position = 1 << (position - '1'); feat(i, 7) = static_cast<uint8_t>(phase);
feat(i, 8) = position_to_id(position);
} }
void _set_obs_action_option(TArray<uint8_t> &feat, int i, char option) { void _set_obs_action_position(TArray<uint8_t> &feat, int i, uint8_t position) {
feat(i, 9) = option - '0'; feat(i, 8) = position_to_id(position);
} }
void _set_obs_action_number(TArray<uint8_t> &feat, int i, char number) { void _set_obs_action_number(TArray<uint8_t> &feat, int i, uint8_t number) {
feat(i, 10) = number - '0'; feat(i, 9) = number;
} }
void _set_obs_action_place(TArray<uint8_t> &feat, int i, const std::string &spec) { void _set_obs_action_place(TArray<uint8_t> &feat, int i, ActionPlace place) {
feat(i, 11) = cmd_place_to_id(spec); feat(i, 10) = static_cast<uint8_t>(place);
} }
void _set_obs_action_attrib(TArray<uint8_t> &feat, int i, uint8_t attrib) { void _set_obs_action_attrib(TArray<uint8_t> &feat, int i, uint8_t attrib) {
feat(i, 12) = attribute_to_id(attrib); feat(i, 11) = attribute_to_id(attrib);
} }
void _set_obs_action(TArray<uint8_t> &feat, int i, int msg, void _set_obs_action(TArray<uint8_t> &feat, int i, const LegalAction &action) {
const std::string &option, const SpecIndex &spec2index, auto msg = action.msg_;
CardId card_id) {
_set_obs_action_msg(feat, i, msg); _set_obs_action_msg(feat, i, msg);
if (msg == MSG_SELECT_IDLECMD) { _set_obs_action_card_id(feat, i, action.cid_);
if (option == "b" || option == "e") { if (msg == MSG_SELECT_CARD || msg == MSG_SELECT_TRIBUTE ||
_set_obs_action_phase(feat, i, option[0]);
} else {
auto act = option[0];
auto spec = option.substr(2);
uint8_t offset = 0;
int n = spec.size();
if (act == 'v' && std::isalpha(spec[n - 1])) {
offset = spec[n - 1] - 'a';
spec = spec.substr(0, n - 1);
}
_set_obs_action_act(feat, i, act, offset);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_CHAIN) {
if (option[0] == 'c') {
_set_obs_action_cancel(feat, i);
} else {
char act = 'v';
auto spec = option;
uint8_t offset = 0;
auto n = spec.size();
if (std::isalpha(spec[n - 1])) {
offset = spec[n - 1] - 'a';
spec = spec.substr(0, n - 1);
}
_set_obs_action_act(feat, i, act, offset);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_CARD || msg == MSG_SELECT_TRIBUTE ||
msg == MSG_SELECT_SUM || msg == MSG_SELECT_UNSELECT_CARD) { msg == MSG_SELECT_SUM || msg == MSG_SELECT_UNSELECT_CARD) {
if (option[0] == 'f') { if (action.finish_) {
_set_obs_action_finish(feat, i); _set_obs_action_finish(feat, i);
} else { } else {
_set_obs_action_spec(feat, i, option, spec2index, card_id); _set_obs_action_spec(feat, i, action.spec_index_);
} }
} else if (msg == MSG_SELECT_POSITION) { } else if (msg == MSG_SELECT_POSITION) {
_set_obs_action_position(feat, i, option[0]); _set_obs_action_position(feat, i, action.position_);
} else if (msg == MSG_SELECT_EFFECTYN) { } else if (msg == MSG_SELECT_EFFECTYN) {
auto spec = option.substr(2); _set_obs_action_spec(feat, i, action.spec_index_);
_set_obs_action_spec(feat, i, spec, spec2index, card_id); _set_obs_action_act(feat, i, action.act_);
_set_obs_action_effect(feat, i, action.effect_);
_set_obs_action_yesno(feat, i, option[0]); } else if (msg == MSG_SELECT_YESNO || msg == MSG_SELECT_OPTION) {
} else if (msg == MSG_SELECT_YESNO) { _set_obs_action_act(feat, i, action.act_);
_set_obs_action_yesno(feat, i, option[0]); _set_obs_action_effect(feat, i, action.effect_);
} else if (msg == MSG_SELECT_BATTLECMD) { } else if (
if (option == "m" || option == "e") { msg == MSG_SELECT_BATTLECMD ||
_set_obs_action_phase(feat, i, option[0]); msg == MSG_SELECT_IDLECMD ||
} else { msg == MSG_SELECT_CHAIN) {
auto act = option[0]; _set_obs_action_phase(feat, i, action.phase_);
auto spec = option.substr(2); _set_obs_action_spec(feat, i, action.spec_index_);
_set_obs_action_act(feat, i, act); _set_obs_action_act(feat, i, action.act_);
_set_obs_action_spec(feat, i, spec, spec2index, card_id); _set_obs_action_effect(feat, i, action.effect_);
}
} else if (msg == MSG_SELECT_OPTION) {
_set_obs_action_option(feat, i, option[0]);
} else if (msg == MSG_SELECT_PLACE || msg_ == MSG_SELECT_DISFIELD) { } else if (msg == MSG_SELECT_PLACE || msg_ == MSG_SELECT_DISFIELD) {
_set_obs_action_place(feat, i, option); _set_obs_action_place(feat, i, action.place_);
} else if (msg == MSG_ANNOUNCE_ATTRIB) { } else if (msg == MSG_ANNOUNCE_ATTRIB) {
_set_obs_action_attrib(feat, i, 1 << (option[0] - '1')); _set_obs_action_attrib(feat, i, action.attribute_);
} else if (msg == MSG_ANNOUNCE_NUMBER) { } else if (msg == MSG_ANNOUNCE_NUMBER) {
_set_obs_action_number(feat, i, option[0]); _set_obs_action_number(feat, i, action.number_);
} else { } else {
throw std::runtime_error("Unsupported message " + std::to_string(msg)); throw std::runtime_error("Unsupported message " + std::to_string(msg));
} }
...@@ -2302,49 +2523,42 @@ private: ...@@ -2302,49 +2523,42 @@ private:
CardId spec_to_card_id(const std::string &spec, PlayerId player) { CardId spec_to_card_id(const std::string &spec, PlayerId player) {
int offset = 0; int offset = 0;
// TODO: possible info leak bool opponent = false;
if (spec[0] == 'o') { if (spec[0] == 'o') {
player = 1 - player; player = 1 - player;
opponent = true;
offset++; offset++;
} }
auto [loc, seq, pos] = spec_to_ls(spec.substr(offset)); auto [loc, seq, pos] = spec_to_ls(spec.substr(offset));
return c_get_card_id(get_card_code(player, loc, seq)); if (opponent) {
} bool hidden_for_opponent = true;
if (
CardId parse_card_id(const std::string &option, PlayerId player) { loc == LOCATION_MZONE || loc == LOCATION_SZONE ||
CardId card_id = 0; loc == LOCATION_GRAVE || loc == LOCATION_REMOVED) {
if (msg_ == MSG_SELECT_IDLECMD) { hidden_for_opponent = false;
if (!(option == "b" || option == "e")) {
auto n = option.size();
if (std::isalpha(option[n - 1])) {
card_id = spec_to_card_id(option.substr(2, n - 3), player);
} else {
card_id = spec_to_card_id(option.substr(2), player);
} }
if (revealed_.size() != 0) {
hidden_for_opponent = false;
} }
} else if (msg_ == MSG_SELECT_CHAIN) { if (hidden_for_opponent) {
if (option != "c") { return 0;
card_id = spec_to_card_id(option, player);
} }
} else if (msg_ == MSG_SELECT_CARD || msg_ == MSG_SELECT_TRIBUTE || Card c = get_card(player, loc, seq);
msg_ == MSG_SELECT_SUM || msg_ == MSG_SELECT_UNSELECT_CARD) { bool hide = c.position_ & POS_FACEDOWN;
if (option[0] != 'f') { if (revealed_.find(spec) != revealed_.end()) {
card_id = spec_to_card_id(option, player); hide = false;
} }
} else if (msg_ == MSG_SELECT_EFFECTYN) { CardId card_id = 0;
card_id = spec_to_card_id(option.substr(2), player); if (!hide) {
} else if (msg_ == MSG_SELECT_BATTLECMD) { card_id = c_get_card_id(c.code_);
if (!(option == "m" || option == "e")) {
card_id = spec_to_card_id(option.substr(2), player);
} }
} }
return card_id; return c_get_card_id(get_card_code(player, loc, seq));
} }
void _set_obs_actions(TArray<uint8_t> &feat, const SpecIndex &spec2index, void _set_obs_actions(TArray<uint8_t> &feat, const std::vector<LegalAction> &actions) {
int msg, const std::vector<std::string> &options) { for (int i = 0; i < actions.size(); ++i) {
for (int i = 0; i < options.size(); ++i) { _set_obs_action(feat, i, actions[i]);
_set_obs_action(feat, i, msg, options[i], spec2index, 0);
} }
} }
...@@ -2451,7 +2665,7 @@ private: ...@@ -2451,7 +2665,7 @@ private:
void WriteState(float reward, int win_reason = 0) { void WriteState(float reward, int win_reason = 0) {
State state = Allocate(); State state = Allocate();
int n_options = options_.size(); int n_options = legal_actions_.size();
state["reward"_] = reward; state["reward"_] = reward;
state["info:to_play"_] = int(to_play_); state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay); state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
...@@ -2463,62 +2677,69 @@ private: ...@@ -2463,62 +2677,69 @@ private:
return; return;
} }
auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_); auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards); _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// we can't shuffle because idx must be stable in callback // we can't shuffle because idx must be stable in callback
if (n_options > max_options()) { if (n_options > max_options()) {
options_.resize(max_options()); legal_actions_.resize(max_options());
} }
// print spec2index n_options = legal_actions_.size();
// for (auto const& [key, val] : spec2index) {
// fmt::println("{} {}", key, val);
// }
_set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
n_options = options_.size();
state["info:num_options"_] = n_options; state["info:num_options"_] = n_options;
// update_h_card_ids from state
for (int i = 0; i < n_options; ++i) { for (int i = 0; i < n_options; ++i) {
uint8_t spec_index1 = state["obs:actions_"_](i, 0); auto &action = legal_actions_[i];
uint8_t spec_index2 = state["obs:actions_"_](i, 1); action.msg_ = msg_;
uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2); const auto &spec = action.spec_;
if (spec_index == 0) { if (!spec.empty()) {
h_card_ids_[i] = 0; const auto& spec_info = find_spec_info(spec_infos, spec);
} else { action.spec_index_ = spec_info.index;
uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0); if (action.cid_ == 0) {
uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1); action.cid_ = spec_info.cid;
h_card_ids_[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2); }
} }
} }
_set_obs_actions(state["obs:actions_"_], legal_actions_);
// write history actions // write history actions
int offset = n_history_actions_ - ha_p_; auto ha_p = to_play_ == 0 ? ha_p_1_ : ha_p_2_;
int n_h_action_feats = history_actions_.Shape()[1]; auto &history_actions = to_play_ == 0 ? history_actions_1_ : history_actions_2_;
int offset = n_history_actions_ - ha_p;
int n_h_action_feats = history_actions.Shape()[1];
state["obs:h_actions_"_].Assign( state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions_[ha_p_].Data(), n_h_action_feats * offset); (uint8_t *)history_actions[ha_p].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign( state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions_.Data(), n_h_action_feats * ha_p_); (uint8_t *)history_actions.Data(), n_h_action_feats * ha_p);
for (int i = 0; i < n_history_actions_; ++i) { for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 2)) == 0) { if (uint8_t(state["obs:h_actions_"_](i, 3)) == 0) {
break; break;
} }
state["obs:h_actions_"_](i, 13) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 13)) == to_play_); // state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 14))); int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 12)));
state["obs:h_actions_"_](i, 14) = static_cast<uint8_t>(turn_diff); state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(turn_diff);
} }
} }
void show_decision(int idx) { void show_decision(int idx) {
fmt::println("Player {} chose \"{}\" in {}", to_play_, options_[idx], std::string s;
options_); const auto& a = legal_actions_[idx];
if (!a.spec_.empty()) {
s = a.spec_;
} else if (a.place_ != ActionPlace::None) {
s = action_place_to_string(a.place_);
} else if (a.position_ != 0) {
s = position_to_string(a.position_);
} else {
s = fmt::format("{}", a);
}
fmt::print("Player {} chose \"{}\" in {}\n", to_play_, s, legal_actions_);
} }
std::tuple<std::vector<CardCode>, std::vector<CardCode>, std::string> std::tuple<std::vector<CardCode>, std::vector<CardCode>, std::string>
...@@ -2581,15 +2802,19 @@ private: ...@@ -2581,15 +2802,19 @@ private:
handle_multi_select(); handle_multi_select();
} else { } else {
handle_message(); handle_message();
if (options_.empty()) { if (legal_actions_.empty()) {
continue; continue;
} }
} }
if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) { if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) {
if (options_.size() == 1) { if (legal_actions_.size() == 1) {
callback_(0); callback_(0);
update_h_card_ids(to_play_, 0); auto la = legal_actions_[0];
update_history_actions(to_play_, 0); la.msg_ = msg_;
if (la.cid_ == 0 && !la.spec_.empty()) {
la.cid_ = spec_to_card_id(la.spec_, to_play_);
}
update_history_actions(to_play_, la);
if (verbose_) { if (verbose_) {
show_decision(0); show_decision(0);
} }
...@@ -2597,7 +2822,7 @@ private: ...@@ -2597,7 +2822,7 @@ private:
return; return;
} }
} else { } else {
auto idx = players_[to_play_]->think(options_); auto idx = players_[to_play_]->think(legal_actions_);
callback_(idx); callback_(idx);
if (verbose_) { if (verbose_) {
show_decision(idx); show_decision(idx);
...@@ -2606,7 +2831,7 @@ private: ...@@ -2606,7 +2831,7 @@ private:
} }
} }
done_ = true; done_ = true;
options_.clear(); legal_actions_.clear();
} }
uint8_t read_u8() { return data_[dp_++]; } uint8_t read_u8() { return data_[dp_++]; }
...@@ -2653,7 +2878,12 @@ private: ...@@ -2653,7 +2878,12 @@ private:
int32_t bl = YGO_QueryCard(pduel_, player, loc, seq, flags, query_buf_); int32_t bl = YGO_QueryCard(pduel_, player, loc, seq, flags, query_buf_);
qdp_ = 0; qdp_ = 0;
if (bl <= 0) { if (bl <= 0) {
throw std::runtime_error("[get_card] Invalid card (bl <= 0)"); show_deck(0);
show_deck(1);
show_turn();
show_buffer();
auto s = fmt::format("[get_card] Invalid card (bl <= 0), player: {}, loc: {}, seq: {}", player, loc, seq);
throw std::runtime_error(s);
} }
uint32_t f = q_read_u32(); uint32_t f = q_read_u32();
if (f == LEN_EMPTY) { if (f == LEN_EMPTY) {
...@@ -2728,7 +2958,7 @@ private: ...@@ -2728,7 +2958,7 @@ private:
c.attack_ = q_read_u32(); c.attack_ = q_read_u32();
c.defense_ = q_read_u32(); c.defense_ = q_read_u32();
// TODO: equip_target // TODO(2): equip_target
if (f & QUERY_EQUIP_CARD) { if (f & QUERY_EQUIP_CARD) {
q_read_u32(); q_read_u32();
} }
...@@ -2744,7 +2974,7 @@ private: ...@@ -2744,7 +2974,7 @@ private:
cards.push_back(c_); cards.push_back(c_);
} }
// TODO: counters // TODO(2): counters
uint32_t n_counters = q_read_u32(); uint32_t n_counters = q_read_u32();
for (int i = 0; i < n_counters; ++i) { for (int i = 0; i < n_counters; ++i) {
if (i == 0) { if (i == 0) {
...@@ -2803,7 +3033,7 @@ private: ...@@ -2803,7 +3033,7 @@ private:
auto controller = read_u8(); auto controller = read_u8();
auto loc = read_u8(); auto loc = read_u8();
auto seq = read_u8(); auto seq = read_u8();
uint32_t data = -1; uint32_t data = 0;
if (extra) { if (extra) {
if (extra8) { if (extra8) {
data = read_u8(); data = read_u8();
...@@ -2816,6 +3046,23 @@ private: ...@@ -2816,6 +3046,23 @@ private:
return card_specs; return card_specs;
} }
// Splits an effect description value into a (card code, effect index) pair.
//
// - desc below DESCRIPTION_LIMIT: returned unchanged as {0, desc}; the card
//   code slot is 0 (presumably a system-string id -- confirm against callers).
// - otherwise: the bits above the low nibble (desc >> 4) are returned as the
//   card code, and the low nibble (desc & 0xf) is returned shifted by
//   CARD_EFFECT_OFFSET as the effect index.
//
// `code` is used only in the diagnostic line printed on failure.
//
// Throws std::runtime_error when the low nibble is 14 or 15, after printing
// the offending values and calling the show_deck / show_buffer / show_turn
// debugging helpers.
std::tuple<CardCode, int> unpack_desc(CardCode code, uint32_t desc) {
  if (desc < DESCRIPTION_LIMIT) {
    return {0, desc};
  }

  const CardCode owner_code = desc >> 4;
  const int slot = desc & 0xf;

  // Only slots 0..13 are accepted; a 4-bit field can also hold 14 and 15.
  const bool slot_in_range = (slot >= 0) && (slot < 14);
  if (slot_in_range) {
    return {owner_code, slot + CARD_EFFECT_OFFSET};
  }

  // Emit as much context as possible before aborting.
  fmt::print("Code: {}, Code_: {}, Desc: {}\n", code, owner_code, desc);
  show_deck(0);
  show_deck(1);
  show_buffer();
  show_turn();
  throw std::runtime_error("Invalid effect index: " + std::to_string(slot));
}
std::string cardlist_info_for_player(const Card &card, PlayerId pl) { std::string cardlist_info_for_player(const Card &card, PlayerId pl) {
std::string spec = card.get_spec(pl); std::string spec = card.get_spec(pl);
if (card.location_ == LOCATION_DECK) { if (card.location_ == LOCATION_DECK) {
...@@ -2833,7 +3080,7 @@ private: ...@@ -2833,7 +3080,7 @@ private:
// 3. update to_play_ and options_ if need action // 3. update to_play_ and options_ if need action
void handle_message() { void handle_message() {
msg_ = int(data_[dp_++]); msg_ = int(data_[dp_++]);
options_ = {}; legal_actions_ = {};
if (verbose_) { if (verbose_) {
fmt::println("Message {}, length {}, dp {}", msg_to_string(msg_), dl_, dp_); fmt::println("Message {}, length {}, dp {}", msg_to_string(msg_), dl_, dp_);
...@@ -3097,11 +3344,11 @@ private: ...@@ -3097,11 +3344,11 @@ private:
uint8_t pos = read_u8(); uint8_t pos = read_u8();
uint8_t type = read_u8(); uint8_t type = read_u8();
uint32_t value = read_u32(); uint32_t value = read_u32();
if (type == CHINT_RACE) {
Card card = get_card(player, loc, seq); Card card = get_card(player, loc, seq);
if (card.code_ == 0) { if (card.code_ == 0) {
return; return;
} }
if (type == CHINT_RACE) {
std::string races_str = "TODO"; std::string races_str = "TODO";
for (PlayerId pl = 0; pl < 2; pl++) { for (PlayerId pl = 0; pl < 2; pl++) {
players_[pl]->notify(fmt::format("{} ({}) selected {}.", players_[pl]->notify(fmt::format("{} ({}) selected {}.",
...@@ -3109,6 +3356,10 @@ private: ...@@ -3109,6 +3356,10 @@ private:
races_str)); races_str));
} }
} else if (type == CHINT_ATTRIBUTE) { } else if (type == CHINT_ATTRIBUTE) {
Card card = get_card(player, loc, seq);
if (card.code_ == 0) {
return;
}
std::string attributes_str = "TODO"; std::string attributes_str = "TODO";
for (PlayerId pl = 0; pl < 2; pl++) { for (PlayerId pl = 0; pl < 2; pl++) {
players_[pl]->notify(fmt::format("{} ({}) selected {}.", players_[pl]->notify(fmt::format("{} ({}) selected {}.",
...@@ -3229,7 +3480,7 @@ private: ...@@ -3229,7 +3480,7 @@ private:
return; return;
} }
dp_ += 6; dp_ += 6;
// TODO: implement output // TODO(3): implement output
} else if (msg_ == MSG_CARD_TARGET) { } else if (msg_ == MSG_CARD_TARGET) {
if (!verbose_) { if (!verbose_) {
dp_ = dl_; dp_ = dl_;
...@@ -3301,7 +3552,7 @@ private: ...@@ -3301,7 +3552,7 @@ private:
players_[pl]->notify(str); players_[pl]->notify(str);
} }
} else if (msg_ == MSG_SORT_CARD) { } else if (msg_ == MSG_SORT_CARD) {
// TODO: implement action // TODO(3): implement action
if (!verbose_) { if (!verbose_) {
dp_ = dl_; dp_ = dl_;
resp_buf_[0] = 255; resp_buf_[0] = 255;
...@@ -3374,7 +3625,7 @@ private: ...@@ -3374,7 +3625,7 @@ private:
auto pl = players_[player]; auto pl = players_[player];
PlayerId op_id = 1 - player; PlayerId op_id = 1 - player;
auto op = players_[op_id]; auto op = players_[op_id];
// TODO: counter type to string // TODO(3): counter type to string
pl->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(player))); pl->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(player)));
op->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(op_id))); op->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(op_id)));
} else if (msg_ == MSG_REMOVE_COUNTER) { } else if (msg_ == MSG_REMOVE_COUNTER) {
...@@ -3406,7 +3657,7 @@ private: ...@@ -3406,7 +3657,7 @@ private:
dp_ = dl_; dp_ = dl_;
return; return;
} }
// TODO: implement output // TODO(3): implement output
dp_ = dl_; dp_ = dl_;
} else if (msg_ == MSG_SHUFFLE_DECK) { } else if (msg_ == MSG_SHUFFLE_DECK) {
if (!verbose_) { if (!verbose_) {
...@@ -3699,52 +3950,64 @@ private: ...@@ -3699,52 +3950,64 @@ private:
if (verbose_) { if (verbose_) {
pl->notify("Battle menu:"); pl->notify("Battle menu:");
} }
for (const auto [code, spec, data] : activatable) { for (const auto [code_t, spec, desc] : activatable) {
// TODO: Add effect description to indicate which effect is being activated CardCode code = code_t;
options_.push_back("v " + spec); if(code & 0x80000000) {
code &= 0x7fffffff;
}
auto [code_d, eff_idx] = unpack_desc(code, desc);
if (desc == 0) {
code_d = code;
}
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
}
legal_actions_.push_back(la);
if (verbose_) { if (verbose_) {
auto [loc, seq, pos] = spec_to_ls(spec); auto c = c_get_card(code);
auto c = get_card(player, loc, seq); int cmd_idx = legal_actions_.size();
pl->notify("v " + spec + ": activate " + c.name_ + " (" + std::string s = fmt::format(
std::to_string(c.attack_) + "/" + "{}: activate {}({}) [{}/{}] ({})",
std::to_string(c.defense_) + ")"); cmd_idx, c.name_, spec, c.attack_, c.defense_, c.get_effect_description(code_d, eff_idx));
} }
} }
for (const auto [code, spec, data] : attackable) { for (const auto [code, spec, data] : attackable) {
// TODO: add this as feature
bool direct_attackable = data & 0x1; bool direct_attackable = data & 0x1;
options_.push_back("a " + spec); auto act = direct_attackable ? ActionAct::DirectAttack : ActionAct::Attack;
legal_actions_.push_back(
LegalAction::act_spec(act, spec));
if (verbose_) { if (verbose_) {
auto [loc, seq, pos] = spec_to_ls(spec); auto [controller, loc, seq, pos] = spec_to_ls(player, spec);
auto c = get_card(player, loc, seq); auto c = get_card(controller, loc, seq);
std::string s; int cmd_idx = legal_actions_.size();
auto attack_str = direct_attackable ? "direct attack" : "attack";
std::string s = fmt::format(
"{}: {} {}({}) ", cmd_idx, attack_str, c.name_, spec);
if (c.type_ & TYPE_LINK) { if (c.type_ & TYPE_LINK) {
s = "a " + spec + ": " + c.name_ + " (" + s += fmt::format("[{}]", c.attack_);
std::to_string(c.attack_) + ")";
} else {
s = "a " + spec + ": " + c.name_ + " (" +
std::to_string(c.attack_) + "/" +
std::to_string(c.defense_) + ")";
}
if (direct_attackable) {
s += " direct attack";
} else { } else {
s += " attack"; s += fmt::format("[{}/{}]", c.attack_, c.defense_);
} }
pl->notify(s); pl->notify(s);
} }
} }
if (to_m2) { if (to_m2) {
options_.push_back("m"); legal_actions_.push_back(
LegalAction::phase(ActionPhase::Main2));
int cmd_idx = legal_actions_.size();
if (verbose_) { if (verbose_) {
pl->notify("m: Main phase 2."); pl->notify(fmt::format("{}: Main phase 2.", cmd_idx));
} }
} }
if (to_ep) { if (to_ep) {
if (!to_m2) { if (!to_m2) {
options_.push_back("e"); legal_actions_.push_back(
LegalAction::phase(ActionPhase::End));
int cmd_idx = legal_actions_.size();
if (verbose_) { if (verbose_) {
pl->notify("e: End phase."); pl->notify(fmt::format("{}: End phase.", cmd_idx));
} }
} }
} }
...@@ -3752,14 +4015,15 @@ private: ...@@ -3752,14 +4015,15 @@ private:
int n_attackables = attackable.size(); int n_attackables = attackable.size();
to_play_ = player; to_play_ = player;
callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) { callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) {
const auto &la = legal_actions_[idx];
if (idx < n_activatables) { if (idx < n_activatables) {
YGO_SetResponsei(pduel_, idx << 16); YGO_SetResponsei(pduel_, idx << 16);
} else if (idx < (n_activatables + n_attackables)) { } else if (idx < (n_activatables + n_attackables)) {
idx = idx - n_activatables; idx = idx - n_activatables;
YGO_SetResponsei(pduel_, (idx << 16) + 1); YGO_SetResponsei(pduel_, (idx << 16) + 1);
} else if ((options_[idx] == "e") && to_ep) { } else if ((la.phase_ == ActionPhase::End) && to_ep) {
YGO_SetResponsei(pduel_, 3); YGO_SetResponsei(pduel_, 3);
} else if ((options_[idx] == "m") && to_m2) { } else if ((la.phase_ == ActionPhase::Main2) && to_m2) {
YGO_SetResponsei(pduel_, 2); YGO_SetResponsei(pduel_, 2);
} else { } else {
throw std::runtime_error("Invalid option"); throw std::runtime_error("Invalid option");
...@@ -3777,21 +4041,18 @@ private: ...@@ -3777,21 +4041,18 @@ private:
std::vector<std::string> select_specs; std::vector<std::string> select_specs;
select_specs.reserve(select_size); select_specs.reserve(select_size);
if (verbose_) { if (verbose_) {
std::vector<Card> cards; auto pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards:");
for (int i = 0; i < select_size; ++i) { for (int i = 0; i < select_size; ++i) {
auto code = read_u32(); auto code = read_u32();
auto loc = read_u32(); auto loc = read_u32();
Card card = c_get_card(code); Card card = c_get_card(code);
card.set_location(loc); card.set_location(loc);
cards.push_back(card);
}
auto pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards:");
for (const auto &card : cards) {
auto spec = card.get_spec(player); auto spec = card.get_spec(player);
select_specs.push_back(spec); select_specs.push_back(spec);
pl->notify(spec + ": " + card.name_); auto s = fmt::format("{}: {}({})", i + 1, card.name_, spec);
pl->notify(s);
} }
} else { } else {
for (int i = 0; i < select_size; ++i) { for (int i = 0; i < select_size; ++i) {
...@@ -3807,22 +4068,22 @@ private: ...@@ -3807,22 +4068,22 @@ private:
auto unselect_size = read_u8(); auto unselect_size = read_u8();
// unselect not allowed (no regrets!) // unselect not allowed (no regrets)
dp_ += 8 * unselect_size; dp_ += 8 * unselect_size;
for (int j = 0; j < select_specs.size(); ++j) { for (int j = 0; j < select_specs.size(); ++j) {
options_.push_back(select_specs[j]); legal_actions_.push_back(LegalAction::from_spec(select_specs[j]));
} }
if (finishable) { if (finishable) {
options_.push_back("f"); legal_actions_.push_back(LegalAction::finish());
} }
// cancelable and finishable not needed // cancelable and finishable not needed
to_play_ = player; to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (options_[idx] == "f") { if (legal_actions_[idx].finish_) {
YGO_SetResponsei(pduel_, -1); YGO_SetResponsei(pduel_, -1);
} else { } else {
resp_buf_[0] = 1; resp_buf_[0] = 1;
...@@ -3893,7 +4154,7 @@ private: ...@@ -3893,7 +4154,7 @@ private:
} }
} }
// TODO: use this when added to history actions // TODO(1): use this when added to history actions
// if ((min == max) && (max == specs.size())) { // if ((min == max) && (max == specs.size())) {
// resp_buf_[0] = specs.size(); // resp_buf_[0] = specs.size();
// for (int i = 0; i < specs.size(); ++i) { // for (int i = 0; i < specs.size(); ++i) {
...@@ -3974,7 +4235,7 @@ private: ...@@ -3974,7 +4235,7 @@ private:
// combs = combinations_with_weight(release_params, min); // combs = combinations_with_weight(release_params, min);
} }
// TODO: use this when added to history actions // TODO(1): use this when added to history actions
// if (max == specs.size()) { // if (max == specs.size()) {
// // tribute all // // tribute all
// resp_buf_[0] = specs.size(); // resp_buf_[0] = specs.size();
...@@ -4126,25 +4387,18 @@ private: ...@@ -4126,25 +4387,18 @@ private:
// auto hint_timing = read_u32(); // auto hint_timing = read_u32();
// auto other_timing = read_u32(); // auto other_timing = read_u32();
std::vector<Card> cards; std::vector<CardCode> codes;
std::vector<uint32_t> descs; std::vector<uint32_t> descs;
std::vector<uint32_t> spec_codes; std::vector<std::string> specs;
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
auto et = read_u8(); auto flag = read_u8();
CardCode code = read_u32(); CardCode code = read_u32();
if (verbose_) { codes.push_back(code);
uint32_t loc = read_u32();
Card card = c_get_card(code);
card.set_location(loc);
cards.push_back(card);
spec_codes.push_back(card.get_spec_code(player));
} else {
PlayerId c = read_u8(); PlayerId c = read_u8();
uint8_t loc = read_u8(); uint8_t loc = read_u8();
uint8_t seq = read_u8(); uint8_t seq = read_u8();
uint8_t pos = read_u8(); uint8_t pos = read_u8();
spec_codes.push_back(ls_to_spec_code(loc, seq, pos, c != player)); specs.push_back(ls_to_spec(loc, seq, pos, c != player));
}
uint32_t desc = read_u32(); uint32_t desc = read_u32();
descs.push_back(desc); descs.push_back(desc);
} }
...@@ -4168,58 +4422,42 @@ private: ...@@ -4168,58 +4422,42 @@ private:
op->seen_waiting_ = true; op->seen_waiting_ = true;
} }
std::vector<int> chain_index;
ankerl::unordered_dense::map<uint32_t, int> chain_counts;
ankerl::unordered_dense::map<uint32_t, int> chain_orders;
std::vector<std::string> chain_specs;
std::vector<std::string> effect_descs;
for (int i = 0; i < size; i++) {
chain_index.push_back(i);
chain_counts[spec_codes[i]] += 1;
}
for (int i = 0; i < size; i++) {
auto spec_code = spec_codes[i];
auto cs = code_to_spec(spec_code);
auto chain_count = chain_counts[spec_code];
if (chain_count > 1) {
// TODO: should use desc to indicate activate which effect
cs.push_back('a' + chain_orders[spec_code]);
}
chain_orders[spec_code]++;
chain_specs.push_back(cs);
if (verbose_) {
const auto &card = cards[i];
effect_descs.push_back(card.get_effect_description(descs[i], true));
}
}
if (verbose_) { if (verbose_) {
if (forced) {
pl->notify("Select chain:"); pl->notify("Select chain:");
} else {
pl->notify("Select chain (c to cancel):");
} }
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
const auto &effect_desc = effect_descs[i]; CardCode code = codes[i];
if (effect_desc.empty()) { uint32_t desc = descs[i];
pl->notify(chain_specs[i] + ": " + cards[i].name_); auto spec = specs[i];
} else { auto [code_d, eff_idx] = unpack_desc(code, desc);
pl->notify(chain_specs[i] + " (" + cards[i].name_ + if (desc == 0) {
"): " + effect_desc); code_d = code;
} }
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
} }
legal_actions_.push_back(la);
if (verbose_) {
auto c = c_get_card(code);
std::string s = fmt::format(
"{}: {}({}) ({})",
i + 1, c.name_, spec, c.get_effect_description(code_d, eff_idx));
pl->notify(s);
} }
for (const auto &spec : chain_specs) {
options_.push_back(spec);
} }
if (!forced) { if (!forced) {
options_.push_back("c"); legal_actions_.push_back(LegalAction::cancel());
if (verbose_) {
pl->notify(fmt::format("{}: cancel", size + 1));
}
} }
to_play_ = player; to_play_ = player;
callback_ = [this, forced](int idx) { callback_ = [this, forced](int idx) {
const auto &option = options_[idx]; const auto &action = legal_actions_[idx];
if (option == "c") { if (action.act_ == ActionAct::Cancel) {
if (forced) { if (forced) {
fmt::print("cancel not allowed in forced chain\n"); fmt::print("cancel not allowed in forced chain\n");
YGO_SetResponsei(pduel_, 0); YGO_SetResponsei(pduel_, 0);
...@@ -4232,58 +4470,76 @@ private: ...@@ -4232,58 +4470,76 @@ private:
}; };
} else if (msg_ == MSG_SELECT_YESNO) { } else if (msg_ == MSG_SELECT_YESNO) {
auto player = read_u8(); auto player = read_u8();
if (verbose_) {
auto desc = read_u32(); auto desc = read_u32();
auto pl = players_[player]; auto [code, eff_idx] = unpack_desc(0, desc);
std::string opt; if (desc == 0) {
if (desc > 10000) { show_buffer();
auto code = desc >> 4; auto s = fmt::format("Unknown desc {} in select_yesno", desc);
auto card = c_get_card(code); throw std::runtime_error(s);
auto opt_idx = desc & 0xf;
if (opt_idx < card.strings_.size()) {
opt = card.strings_[opt_idx];
} }
if (opt.empty()) { auto la = LegalAction::activate_spec(eff_idx, "");
opt = "Unknown question from " + card.name_ + ". Yes or no?"; if (code != 0) {
la.cid_ = c_get_card_id(code);
} }
legal_actions_.push_back(la);
if (verbose_) {
auto pl = players_[player];
std::string s;
if (code == 0) {
s = get_system_string(eff_idx);
} else { } else {
opt = get_system_string(desc); Card c = c_get_card(code);
int cmd_idx = legal_actions_.size();
eff_idx -= CARD_EFFECT_OFFSET;
if (eff_idx >= c.strings_.size()) {
throw std::runtime_error(
fmt::format("Unknown effect {} of {}", eff_idx, c.name_));
} }
pl->notify(opt); auto str = c.strings_[eff_idx];
pl->notify("Please enter y or n."); if (str.empty()) {
} else { str = "effect " + std::to_string(eff_idx);
dp_ += 4; }
s = fmt::format("{} ({})", c.name_, str);
}
pl->notify("1: " + s);
pl->notify("2: No");
} }
options_ = {"y", "n"}; // TODO: maybe add card id to cancel
legal_actions_.push_back(LegalAction::cancel());
to_play_ = player; to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (idx == 0) { if (idx == 0) {
YGO_SetResponsei(pduel_, 1); YGO_SetResponsei(pduel_, 1);
} else if (idx == 1) { } else if (idx == 1) {
YGO_SetResponsei(pduel_, 0); YGO_SetResponsei(pduel_, 0);
} else {
throw std::runtime_error("Invalid option");
} }
}; };
} else if (msg_ == MSG_SELECT_EFFECTYN) { } else if (msg_ == MSG_SELECT_EFFECTYN) {
auto player = read_u8(); auto player = read_u8();
std::string spec;
if (verbose_) {
CardCode code = read_u32(); CardCode code = read_u32();
uint32_t loc = read_u32(); auto ct = read_u8();
Card card = c_get_card(code); auto loc = read_u8();
card.set_location(loc); auto seq = read_u8();
auto pos = read_u8();
auto desc = read_u32(); auto desc = read_u32();
std::string spec = ls_to_spec(loc, seq, pos, ct != player);
auto [code_d, eff_idx] = unpack_desc(code, desc);
if (desc == 0) {
code_d = code;
}
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
}
legal_actions_.push_back(la);
if (verbose_) {
Card c = c_get_card(code);
auto pl = players_[player]; auto pl = players_[player];
spec = card.get_spec(player); auto name = c.name_;
auto name = card.name_;
std::string s; std::string s;
if (desc == 0) { if (code_d == 0) {
// From [%ls], activate [%ls]?
s = "From " + card.get_spec(player) + ", activate " + name + "?";
} else if (desc < 2048) {
s = get_system_string(desc); s = get_system_string(desc);
std::string fmt_str = "[%ls]"; std::string fmt_str = "[%ls]";
auto pos = find_substrs(s, fmt_str); auto pos = find_substrs(s, fmt_str);
...@@ -4295,87 +4551,74 @@ private: ...@@ -4295,87 +4551,74 @@ private:
} else if (pos.size() == 2) { } else if (pos.size() == 2) {
auto p1 = pos[0]; auto p1 = pos[0];
auto p2 = pos[1]; auto p2 = pos[1];
s = s.substr(0, p1) + card.get_spec(player) + s = s.substr(0, p1) + spec +
s.substr(p1 + fmt_str.size(), p2 - p1 - fmt_str.size()) + name + s.substr(p1 + fmt_str.size(), p2 - p1 - fmt_str.size()) + name +
s.substr(p2 + fmt_str.size()); s.substr(p2 + fmt_str.size());
} else { } else {
throw std::runtime_error("Unknown effectyn desc " + throw std::runtime_error("Unknown effectyn desc " +
std::to_string(desc) + " of " + name); std::to_string(desc) + " of " + name);
} }
} else if (desc < 10000u) {
s = get_system_string(desc);
} else {
CardCode code = (desc >> 4) & 0x0fffffff;
uint32_t offset = desc & 0xf;
if (cards_.find(code) != cards_.end()) {
auto &card_ = c_get_card(code);
s = card_.strings_[offset];
if (s.empty()) {
s = "???";
}
} else { } else {
throw std::runtime_error("Unknown effectyn desc " + s = fmt::format(
std::to_string(desc) + " of " + name); "{}({}) ({})", c.name_, spec, c.get_effect_description(code_d, eff_idx));
} }
pl->notify("1: " + s);
pl->notify("2: No");
} }
pl->notify(s);
pl->notify("Please enter y or n."); // TODO: maybe add card info to cancel
} else { legal_actions_.push_back(LegalAction::cancel());
dp_ += 4;
auto c = read_u8();
auto loc = read_u8();
auto seq = read_u8();
auto pos = read_u8();
dp_ += 4;
spec = ls_to_spec(loc, seq, pos, c != player);
}
options_ = {"y " + spec, "n " + spec};
to_play_ = player; to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (idx == 0) { if (idx == 0) {
YGO_SetResponsei(pduel_, 1); YGO_SetResponsei(pduel_, 1);
} else if (idx == 1) { } else if (idx == 1) {
YGO_SetResponsei(pduel_, 0); YGO_SetResponsei(pduel_, 0);
} else {
throw std::runtime_error("Invalid option");
} }
}; };
} else if (msg_ == MSG_SELECT_OPTION) { } else if (msg_ == MSG_SELECT_OPTION) {
// TODO: add card information
auto player = read_u8(); auto player = read_u8();
auto size = read_u8(); auto size = read_u8();
if (verbose_) { if (verbose_) {
auto pl = players_[player]; players_[player]->notify("Select an option:");
pl->notify("Select an option:"); }
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
auto opt = read_u32(); auto desc = read_u32();
auto [code, eff_idx] = unpack_desc(0, desc);
if (desc == 0) {
show_buffer();
auto s = fmt::format("Unknown desc {} in select_option", desc);
throw std::runtime_error(s);
}
auto la = LegalAction::activate_spec(eff_idx, "");
if (code != 0) {
la.cid_ = c_get_card_id(code);
}
legal_actions_.push_back(la);
if (verbose_) {
std::string s; std::string s;
if (opt > 10000) { if (code == 0) {
CardCode code = opt >> 4; s = get_system_string(eff_idx);
s = c_get_card(code).strings_[opt & 0xf];
} else { } else {
s = get_system_string(opt); Card c = c_get_card(code);
int cmd_idx = legal_actions_.size();
eff_idx -= CARD_EFFECT_OFFSET;
if (eff_idx >= c.strings_.size()) {
throw std::runtime_error(
fmt::format("Unknown effect {} of {}", eff_idx, c.name_));
} }
std::string option = std::to_string(i + 1); auto str = c.strings_[eff_idx];
options_.push_back(option); if (str.empty()) {
pl->notify(option + ": " + s); str = "effect " + std::to_string(eff_idx);
} }
} else { s = fmt::format("{} ({})", c.name_, str);
for (int i = 0; i < size; ++i) {
dp_ += 4;
options_.push_back(std::to_string(i + 1));
} }
players_[player]->notify(std::to_string(i + 1) + ": " + s);
} }
to_play_ = player;
callback_ = [this](int idx) {
if (verbose_) {
players_[to_play_]->notify("You selected option " + options_[idx] +
".");
players_[1 - to_play_]->notify(players_[to_play_]->nickname_ +
" selected option " + options_[idx] +
".");
} }
to_play_ = player;
callback_ = [this](int idx) {
YGO_SetResponsei(pduel_, idx); YGO_SetResponsei(pduel_, idx);
}; };
} else if (msg_ == MSG_SELECT_IDLECMD) { } else if (msg_ == MSG_SELECT_IDLECMD) {
...@@ -4397,90 +4640,97 @@ private: ...@@ -4397,90 +4640,97 @@ private:
pl->notify("Select a card and action to perform."); pl->notify("Select a card and action to perform.");
} }
for (const auto &[code, spec, data] : summonable_) { for (const auto &[code, spec, data] : summonable_) {
std::string option = "s " + spec; legal_actions_.push_back(LegalAction::act_spec(ActionAct::Summon, spec));
options_.push_back(option);
if (verbose_) { if (verbose_) {
const auto &name = c_get_card(code).name_; const auto &name = c_get_card(code).name_;
pl->notify(option + ": Summon " + name + int cmd_idx = legal_actions_.size();
" in face-up attack position."); pl->notify(fmt::format(
"{}: Summon {} in face-up attack position", cmd_idx, name));
} }
} }
offset += summonable_.size(); offset += summonable_.size();
int spsummon_offset = offset; int spsummon_offset = offset;
for (const auto &[code, spec, data] : spsummon_) { for (const auto &[code, spec, data] : spsummon_) {
std::string option = "c " + spec; legal_actions_.push_back(LegalAction::act_spec(ActionAct::SpSummon, spec));
options_.push_back(option);
if (verbose_) { if (verbose_) {
const auto &name = c_get_card(code).name_; const auto &name = c_get_card(code).name_;
pl->notify(option + ": Special summon " + name + "."); int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Special summon {}", cmd_idx, name));
} }
} }
offset += spsummon_.size(); offset += spsummon_.size();
int repos_offset = offset; int repos_offset = offset;
for (const auto &[code, spec, data] : repos_) { for (const auto &[code, spec, data] : repos_) {
std::string option = "r " + spec; legal_actions_.push_back(LegalAction::act_spec(ActionAct::Repo, spec));
options_.push_back(option);
if (verbose_) { if (verbose_) {
const auto &name = c_get_card(code).name_; const auto &name = c_get_card(code).name_;
pl->notify(option + ": Reposition " + name + "."); int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Change position of {}", cmd_idx, name));
} }
} }
offset += repos_.size(); offset += repos_.size();
int mset_offset = offset; int mset_offset = offset;
for (const auto &[code, spec, data] : idle_mset_) { for (const auto &[code, spec, data] : idle_mset_) {
std::string option = "m " + spec; legal_actions_.push_back(LegalAction::act_spec(ActionAct::MSet, spec));
options_.push_back(option);
if (verbose_) { if (verbose_) {
const auto &name = c_get_card(code).name_; const auto &name = c_get_card(code).name_;
pl->notify(option + ": Summon " + name + int cmd_idx = legal_actions_.size();
" in face-down defense position."); pl->notify(fmt::format(
"{}: Summon {} in face-down defense position", cmd_idx, name));
} }
} }
offset += idle_mset_.size(); offset += idle_mset_.size();
int set_offset = offset; int set_offset = offset;
for (const auto &[code, spec, data] : idle_set_) { for (const auto &[code, spec, data] : idle_set_) {
std::string option = "t " + spec; legal_actions_.push_back(LegalAction::act_spec(ActionAct::Set, spec));
options_.push_back(option);
if (verbose_) { if (verbose_) {
const auto &name = c_get_card(code).name_; const auto &name = c_get_card(code).name_;
pl->notify(option + ": Set " + name + "."); int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Set {}", cmd_idx, name));
} }
} }
offset += idle_set_.size(); offset += idle_set_.size();
int activate_offset = offset; int activate_offset = offset;
ankerl::unordered_dense::map<std::string, int> idle_activate_count; for (const auto &[code_t, spec, desc] : idle_activate_) {
for (const auto &[code, spec, data] : idle_activate_) { CardCode code = code_t;
idle_activate_count[spec] += 1; if(code & 0x80000000) {
} code &= 0x7fffffff;
ankerl::unordered_dense::map<std::string, int> activate_count; }
for (const auto &[code, spec, data] : idle_activate_) { auto [code_d, eff_idx] = unpack_desc(code, desc);
// TODO: use effect description to indicate which effect to activate if (desc == 0) {
std::string option = "v " + spec; code_d = code;
int count = idle_activate_count[spec]; }
activate_count[spec]++; auto la = LegalAction::activate_spec(eff_idx, spec);
if (count > 1) { if (code_d != 0) {
option.push_back('a' + activate_count[spec] - 1); la.cid_ = c_get_card_id(code_d);
} }
options_.push_back(option); legal_actions_.push_back(la);
if (verbose_) { if (verbose_) {
pl->notify(option + ": " + auto c = c_get_card(code);
c_get_card(code).get_effect_description(data)); int cmd_idx = legal_actions_.size();
std::string s = fmt::format(
"{}: Activate {}({}) ({})",
cmd_idx, c.name_, spec, c.get_effect_description(code_d, eff_idx));
pl->notify(s);
} }
} }
if (to_bp_) { if (to_bp_) {
std::string cmd = "b"; legal_actions_.push_back(LegalAction::phase(ActionPhase::Battle));
options_.push_back(cmd);
if (verbose_) { if (verbose_) {
pl->notify(cmd + ": Enter the battle phase."); int cmd_idx = legal_actions_.size();
pl->notify(fmt::format("{}: Enter the battle phase.", cmd_idx));
} }
} }
if (to_ep_) { if (to_ep_) {
if (!to_bp_) { if (!to_bp_) {
std::string cmd = "e"; legal_actions_.push_back(LegalAction::phase(ActionPhase::End));
options_.push_back(cmd);
if (verbose_) { if (verbose_) {
pl->notify(cmd + ": End phase."); int cmd_idx = legal_actions_.size();
pl->notify(fmt::format("{}: End phase.", cmd_idx));
} }
} }
} }
...@@ -4488,104 +4738,90 @@ private: ...@@ -4488,104 +4738,90 @@ private:
to_play_ = player; to_play_ = player;
callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset, callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset,
activate_offset](int idx) { activate_offset](int idx) {
const auto &option = options_[idx]; const auto &action = legal_actions_[idx];
char cmd = option[0]; if (action.phase_ == ActionPhase::Battle) {
if (cmd == 'b') {
YGO_SetResponsei(pduel_, 6); YGO_SetResponsei(pduel_, 6);
} else if (cmd == 'e') { } else if (action.phase_ == ActionPhase::End) {
YGO_SetResponsei(pduel_, 7); YGO_SetResponsei(pduel_, 7);
} else { } else {
auto spec = option.substr(2); auto act = action.act_;
if (cmd == 's') { if (act == ActionAct::Summon) {
uint32_t idx_ = idx; uint32_t idx_ = idx;
YGO_SetResponsei(pduel_, idx_ << 16); YGO_SetResponsei(pduel_, idx_ << 16);
} else if (cmd == 'c') { } else if (act == ActionAct::SpSummon) {
uint32_t idx_ = idx - spsummon_offset; uint32_t idx_ = idx - spsummon_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 1); YGO_SetResponsei(pduel_, (idx_ << 16) + 1);
} else if (cmd == 'r') { } else if (act == ActionAct::Repo) {
uint32_t idx_ = idx - repos_offset; uint32_t idx_ = idx - repos_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 2); YGO_SetResponsei(pduel_, (idx_ << 16) + 2);
} else if (cmd == 'm') { } else if (act == ActionAct::MSet) {
uint32_t idx_ = idx - mset_offset; uint32_t idx_ = idx - mset_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 3); YGO_SetResponsei(pduel_, (idx_ << 16) + 3);
} else if (cmd == 't') { } else if (act == ActionAct::Set) {
uint32_t idx_ = idx - set_offset; uint32_t idx_ = idx - set_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 4); YGO_SetResponsei(pduel_, (idx_ << 16) + 4);
} else if (cmd == 'v') { } else if (act == ActionAct::Activate) {
uint32_t idx_ = idx - activate_offset; uint32_t idx_ = idx - activate_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 5); YGO_SetResponsei(pduel_, (idx_ << 16) + 5);
} else {
throw std::runtime_error("Invalid option: " + option);
} }
} }
}; };
} else if (msg_ == MSG_SELECT_PLACE) { } else if (msg_ == MSG_SELECT_PLACE || msg_ == MSG_SELECT_DISFIELD) {
// TODO(1): add card information to select place
auto player = read_u8(); auto player = read_u8();
auto count = read_u8(); auto count = read_u8();
if (count == 0) { if (count == 0) {
count = 1; count = 1;
} }
auto flag = read_u32(); if (count != 1) {
options_ = flag_to_usable_cardspecs(flag); auto s = fmt::format("Select place count {} not implemented for {}",
if (verbose_) { count, msg_ == MSG_SELECT_PLACE ? "place" : "disfield");
std::string specs_str = options_[0]; throw std::runtime_error(s);
for (int i = 1; i < options_.size(); ++i) {
specs_str += ", " + options_[i];
}
if (count == 1) {
players_[player]->notify("Select place for card, one of " +
specs_str + ".");
} else {
players_[player]->notify("Select " + std::to_string(count) +
" places for card, from " + specs_str + ".");
}
}
to_play_ = player;
callback_ = [this, player](int idx) {
int y = player + 1;
std::string spec = options_[idx];
auto plr = player;
if (spec[0] == 'o') {
plr = 1 - player;
spec = spec.substr(1);
}
auto [loc, seq, pos] = spec_to_ls(spec);
resp_buf_[0] = plr;
resp_buf_[1] = loc;
resp_buf_[2] = seq;
YGO_SetResponseb(pduel_, resp_buf_);
};
} else if (msg_ == MSG_SELECT_DISFIELD) {
auto player = read_u8();
auto count = read_u8();
if (count == 0) {
count = 1;
} }
auto flag = read_u32(); auto flag = read_u32();
options_ = flag_to_usable_cardspecs(flag); auto places = flag_to_usable_places(flag);
if (verbose_) { if (verbose_) {
std::string specs_str = options_[0]; auto place_s = msg_ == MSG_SELECT_PLACE ? "place" : "disfield";
for (int i = 1; i < options_.size(); ++i) { auto s = fmt::format("Select {} for card, one of:", place_s);
specs_str += ", " + options_[i]; players_[player]->notify(s);
} }
if (count == 1) { for (int i = 0; i < places.size(); ++i) {
players_[player]->notify("Select place for card, one of " + legal_actions_.push_back(LegalAction::place(places[i]));
specs_str + "."); if (verbose_) {
} else { auto s = fmt::format("{}: {}", i + 1, action_place_to_string(places[i]));
throw std::runtime_error("Select disfield count " + players_[player]->notify(s);
std::to_string(count) + " not implemented");
} }
} }
to_play_ = player; to_play_ = player;
callback_ = [this, player](int idx) { callback_ = [this, player](int idx) {
int y = player + 1; auto place = legal_actions_[idx].place_;
std::string spec = options_[idx]; int i = static_cast<int>(place);
auto plr = player; uint8_t plr = player;
if (spec[0] == 'o') { uint8_t loc;
uint8_t seq;
if (
i >= static_cast<int>(ActionPlace::MZone1) &&
i <= static_cast<int>(ActionPlace::MZone7)) {
loc = LOCATION_MZONE;
seq = i - static_cast<int>(ActionPlace::MZone1);
} else if (
i >= static_cast<int>(ActionPlace::SZone1) &&
i <= static_cast<int>(ActionPlace::SZone8)) {
loc = LOCATION_SZONE;
seq = i - static_cast<int>(ActionPlace::SZone1);
} else if (
i >= static_cast<int>(ActionPlace::OpMZone1) &&
i <= static_cast<int>(ActionPlace::OpMZone7)) {
plr = 1 - player;
loc = LOCATION_MZONE;
seq = i - static_cast<int>(ActionPlace::OpMZone1);
} else if (
i >= static_cast<int>(ActionPlace::OpSZone1) &&
i <= static_cast<int>(ActionPlace::OpSZone8)) {
plr = 1 - player; plr = 1 - player;
spec = spec.substr(1); loc = LOCATION_SZONE;
seq = i - static_cast<int>(ActionPlace::OpSZone1);
} }
auto [loc, seq, pos] = spec_to_ls(spec);
resp_buf_[0] = plr; resp_buf_[0] = plr;
resp_buf_[1] = loc; resp_buf_[1] = loc;
resp_buf_[2] = seq; resp_buf_[2] = seq;
...@@ -4620,7 +4856,7 @@ private: ...@@ -4620,7 +4856,7 @@ private:
// auto spec = ls_to_spec(loc, seq, 0, controller != player); // auto spec = ls_to_spec(loc, seq, 0, controller != player);
// options_.push_back(spec); // options_.push_back(spec);
} }
// TODO: implement action // TODO(2): implement action
n_counters_ = count; n_counters_ = count;
uint16_t resp1 = static_cast<uint16_t>(std::min(counter_count, counters[0])); uint16_t resp1 = static_cast<uint16_t>(std::min(counter_count, counters[0]));
memcpy(resp_buf_, &resp1, 2); memcpy(resp_buf_, &resp1, 2);
...@@ -4644,20 +4880,16 @@ private: ...@@ -4644,20 +4880,16 @@ private:
" not implemented for announce number"); " not implemented for announce number");
} }
numbers.push_back(number); numbers.push_back(number);
options_.push_back(std::string(1, '0' + number)); legal_actions_.push_back(LegalAction::number(number));
} }
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto pl = players_[player];
std::string str = "Select a number, one of: ["; std::string str = "Select a number, one of:";
pl->notify(str);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
str += std::to_string(numbers[i]); pl->notify(fmt::format("{}: {}", i + 1, numbers[i]));
if (i < count - 1) {
str += ", ";
} }
} }
str += "]";
pl->notify(str);
}
to_play_ = player; to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
YGO_SetResponsei(pduel_, idx); YGO_SetResponsei(pduel_, idx);
...@@ -4675,7 +4907,7 @@ private: ...@@ -4675,7 +4907,7 @@ private:
attrs.push_back(i + 1); attrs.push_back(i + 1);
} }
} }
// TODO(2): implement action
if (count != 1) { if (count != 1) {
throw std::runtime_error("Announce attrib count " + throw std::runtime_error("Announce attrib count " +
std::to_string(count) + " not implemented"); std::to_string(count) + " not implemented");
...@@ -4686,40 +4918,28 @@ private: ...@@ -4686,40 +4918,28 @@ private:
pl->notify("Select " + std::to_string(count) + pl->notify("Select " + std::to_string(count) +
" attributes separated by spaces:"); " attributes separated by spaces:");
for (int i = 0; i < attrs.size(); i++) { for (int i = 0; i < attrs.size(); i++) {
pl->notify(std::to_string(attrs[i]) + ": " + pl->notify(fmt::format("{}: {}", i + 1, attribute_to_string(1 << (attrs[i] - 1))));
attribute_to_string(1 << (attrs[i] - 1)));
} }
} }
auto combs = combinations(attrs.size(), count); // auto combs = combinations(attrs.size(), count);
for (const auto &comb : combs) { for (int i = 0; i < attrs.size(); i++) {
std::string option = ""; legal_actions_.push_back(LegalAction::attribute(1 << (attrs[i] - 1)));
for (int j = 0; j < count; ++j) {
option += std::to_string(attrs[comb[j]]);
if (j < count - 1) {
option += " ";
}
}
options_.push_back(option);
} }
to_play_ = player; to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
const auto &option = options_[idx]; const auto &action = legal_actions_[idx];
uint32_t resp = 0; uint32_t resp = 0;
int i = 0; resp |= action.attribute_;
while (i < option.size()) {
resp |= 1 << (option[i] - '1');
i += 2;
}
YGO_SetResponsei(pduel_, resp); YGO_SetResponsei(pduel_, resp);
}; };
} else if (msg_ == MSG_SELECT_POSITION) { } else if (msg_ == MSG_SELECT_POSITION) {
// TODO: add card as feature
auto player = read_u8(); auto player = read_u8();
auto code = read_u32(); auto code = read_u32();
auto valid_pos = read_u8(); auto valid_pos = read_u8();
CardId cid = c_get_card_id(code);
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto pl = players_[player];
...@@ -4727,25 +4947,25 @@ private: ...@@ -4727,25 +4947,25 @@ private:
pl->notify("Select position for " + card.name_ + ":"); pl->notify("Select position for " + card.name_ + ":");
} }
std::vector<uint8_t> positions;
int i = 1;
for (auto pos : {POS_FACEUP_ATTACK, POS_FACEDOWN_ATTACK, for (auto pos : {POS_FACEUP_ATTACK, POS_FACEDOWN_ATTACK,
POS_FACEUP_DEFENSE, POS_FACEDOWN_DEFENSE}) { POS_FACEUP_DEFENSE, POS_FACEDOWN_DEFENSE}) {
if (valid_pos & pos) { if (valid_pos & pos) {
positions.push_back(pos); LegalAction la;
options_.push_back(std::to_string(i)); la.cid_ = cid;
la.position_ = pos;
legal_actions_.push_back(la);
int cmd_idx = legal_actions_.size();
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto pl = players_[player];
pl->notify(fmt::format("{}: {}", i, position_to_string(pos))); pl->notify(fmt::format("{}: {}", cmd_idx, position_to_string(pos)));
} }
} }
i++;
} }
to_play_ = player; to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
uint8_t pos = options_[idx][0] - '1'; uint8_t pos = legal_actions_[idx].position_;
YGO_SetResponsei(pduel_, 1 << pos); YGO_SetResponsei(pduel_, pos);
}; };
} else { } else {
show_deck(0); show_deck(0);
...@@ -4794,4 +5014,52 @@ using YGOProEnvPool = AsyncEnvPool<YGOProEnv>; ...@@ -4794,4 +5014,52 @@ using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
} // namespace ygopro } // namespace ygopro
// fmt formatter for ygopro::LegalAction.
//
// Renders an action as "{name=value, name=value}" and lists only the fields
// that differ from their "unset" value (empty spec, cid 0, None enums,
// finish == false, position 0, effect -1, number 0, attribute 0). An action
// with nothing set renders as "{}".
//
// Parsing is inherited from formatter<string_view>, so the usual string
// format specs (width, fill, alignment, precision) are accepted and applied
// to the rendered text.
template <>
struct fmt::formatter<ygopro::LegalAction>: formatter<string_view> {
  // Builds the textual form of `action` and writes it to `ctx`.
  // Returns the output iterator past the written text.
  template <typename FormatContext>
  auto format(const ygopro::LegalAction& action, FormatContext& ctx) const {
    std::stringstream ss;
    ss << "{";
    if (!action.spec_.empty()) {
      ss << "spec='" << action.spec_ << "', ";
    }
    if (action.cid_ != 0) {
      ss << "cid=" << action.cid_ << ", ";
    }
    if (action.act_ != ygopro::ActionAct::None) {
      ss << "act=" << ygopro::action_act_to_string(action.act_) << ", ";
    }
    if (action.phase_ != ygopro::ActionPhase::None) {
      ss << "phase=" << ygopro::action_phase_to_string(action.phase_) << ", ";
    }
    if (action.finish_) {
      ss << "finish=true, ";
    }
    if (action.position_ != 0) {
      ss << "position=" << ygopro::position_to_string(action.position_) << ", ";
    }
    if (action.effect_ != -1) {
      ss << "effect=" << action.effect_ << ", ";
    }
    if (action.number_ != 0) {
      // Cast so the value is printed as a number even if number_ is a
      // character-sized integer type.
      ss << "number=" << int(action.number_) << ", ";
    }
    if (action.place_ != ygopro::ActionPlace::None) {
      ss << "place=" << ygopro::action_place_to_string(action.place_) << ", ";
    }
    if (action.attribute_ != 0) {
      ss << "attribute=" << ygopro::attribute_to_string(action.attribute_) << ", ";
    }
    std::string s = ss.str();
    // Every emitted field ends with ", "; strip the separator left behind by
    // the last one. `s` always starts with '{', so back() is safe.
    if (s.back() == ' ') {
      s.pop_back();
      s.pop_back();
    }
    s.push_back('}');
    // Delegate to the base formatter instead of writing with a bare "{}":
    // the inherited parse() has already accepted any format spec, and only
    // the base format() actually honours it.
    return formatter<string_view>::format(string_view(s.data(), s.size()), ctx);
  }
};
#endif // YGOENV_YGOPRO_YGOPRO_H_ #endif // YGOENV_YGOPRO_YGOPRO_H_
from ygoenv.python.api import py_env
from .ygopro0_ygoenv import (
_YGOPro0EnvPool,
_YGOPro0EnvSpec,
init_module,
)
# Pass the raw extension spec/pool classes through py_env; it returns four
# classes, which are bound in order to the public names below.
(
YGOPro0EnvSpec,
YGOPro0DMEnvPool,
YGOPro0GymEnvPool,
YGOPro0GymnasiumEnvPool,
) = py_env(_YGOPro0EnvSpec, _YGOPro0EnvPool)
# Names exported by a wildcard import of this package.
# NOTE(review): `init_module` is imported by this file but is not listed
# here -- presumably it is meant to be imported by name only; confirm.
__all__ = [
"YGOPro0EnvSpec",
"YGOPro0DMEnvPool",
"YGOPro0GymEnvPool",
"YGOPro0GymnasiumEnvPool",
]
from ygoenv.registration import register
# Register the task id "YGOPro-v0". The spec and pool classes are given by
# name (as strings) together with the import path "ygoenv.ygopro0" rather
# than as class objects.
register(
task_id="YGOPro-v0",
import_path="ygoenv.ygopro0",
spec_cls="YGOPro0EnvSpec",
dm_cls="YGOPro0DMEnvPool",
gym_cls="YGOPro0GymEnvPool",
gymnasium_cls="YGOPro0GymnasiumEnvPool",
)
#include "ygoenv/ygopro0/ygopro.h"
#include "ygoenv/core/py_envpool.h"
// Python-facing wrapper types built from the ygopro0 env spec and env pool.
using YGOPro0EnvSpec = PyEnvSpec<ygopro0::YGOProEnvSpec>;
using YGOPro0EnvPool = PyEnvPool<ygopro0::YGOProEnvPool>;
// Defines the `ygopro0_ygoenv` extension module; `m` is the module handle.
PYBIND11_MODULE(ygopro0_ygoenv, m) {
// REGISTER is a macro defined outside this file (presumably in
// py_envpool.h -- confirm); it is handed the module and both wrapper types.
REGISTER(m, YGOPro0EnvSpec, YGOPro0EnvPool)
// Expose ygopro0::init_module to Python under the name `init_module`.
m.def("init_module", &ygopro0::init_module);
}
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment