Commit 04e61b91 authored by sbl1996@126.com's avatar sbl1996@126.com

Add YGOPro-v1

parent f6139c17
# Action
## Types
- Set + card
- Reposition + card
- Special summon + card
- Summon Face-up Attack + card
- Summon Face-down Defense + card
- Attack + card
- DirectAttack + card
- Activate + card + effect
- Cancel
- Switch + phase
- SelectPosition + card + position
- AnnounceNumber + card + effect + number
- SelectPlace + card + place
- AnnounceAttrib + card + effect + attrib
## Effect
### MSG_SELECT_BATTLECMD | MSG_SELECT_IDLECMD | MSG_SELECT_CHAIN | MSG_SELECT_EFFECTYN
- desc == 0: default effect of card
- desc < LIMIT: system string
- desc > LIMIT: card + effect
### MSG_SELECT_OPTION | MSG_SELECT_YESNO
- desc == 0: error
- desc < LIMIT: system string
- desc > LIMIT: card + effect
......@@ -49,50 +49,42 @@ The card id is the index of the card code in `code_list.txt`.
## Legal Actions
- 0,1: spec index, uint16 -> 2 uint8
- 2: msg, discrete, 0: N/A, 1+: same as msg2str (15)
- 3: act, discrete (11)
- 0: spec index
- 1,2: code, uint16 -> 2 uint8
- 3: msg, discrete, 0: N/A, 1+: same as msg2str (15)
- 4: act, discrete (11)
- N/A
- t: Set
- r: Reposition
- c: Special Summon
- s: Summon Face-up Attack
- m: Summon Face-down Defense
- a: Attack
- v: Activate
- v2: Activate the second effect
- v3: Activate the third effect
- v4: Activate the fourth effect
- 4: yes/no, discrete (3)
- Set
- Reposition
- Special Summon
- Summon Face-up Attack
- Summon Face-down Defense
- Attack
- DirectAttack
- Activate
- Cancel
- 5: finish, discrete (2)
- N/A
- Yes
- No
- 5: phase, discrete (4)
- Finish
- 6: effect, discrete, 0: N/A
- 7: phase, discrete (4)
- N/A
- Battle (b)
- Main Phase 2 (m)
- End Phase (e)
- 6: cancel, discrete (2)
- N/A
- Cancel
- 7: finish, discrete (2)
- N/A
- Finish
- 8: position, discrete, 0: N/A, same as position2str
- 9: option, discrete, 0: N/A
- 10: number, discrete, 0: N/A
- 11: place, discrete
- 9: number, discrete, 0: N/A
- 10: place, discrete
- 0: N/A
- 1-7: m
- 8-15: s
- 16-22: om
- 23-30: os
- 12: attribute, discrete, 0: N/A, same as attribute2id
- 11: attribute, discrete, 0: N/A, same as attribute2id
## History Actions
- 0,1: card id, uint16 -> 2 uint8
- 2-12 same as legal actions
- 13: player, discrete, 0: me, 1: oppo
- 14: turn, discrete, trunc to 3
- 2-11 same as legal actions
- 12: turn, discrete, trunc to 3
- 13: phase, discrete (10)
......@@ -18,7 +18,7 @@ import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
@dataclass
......
......@@ -135,7 +135,7 @@ if __name__ == "__main__":
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import RNNAgent
from ygoai.rl.jax.agent import RNNAgent
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
......@@ -168,7 +168,6 @@ if __name__ == "__main__":
obs, infos = envs.reset()
print(obs)
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
......
......@@ -8,6 +8,24 @@ add_requires(
"sqlitecpp 3.2.1")
target("ygopro0_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/ygopro0/*.cpp")
add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "ygopro-core")
set_languages("c++17")
if is_mode("release") then
set_policy("build.optimization.lto", true)
add_cxxflags("-march=native")
end
add_includedirs("ygoenv")
after_build(function (target)
local install_target = "$(projectdir)/ygoenv/ygoenv/ygopro0"
os.cp(target:targetfile(), install_target)
print("Copy target to " .. install_target)
end)
target("ygopro_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/ygopro/*.cpp")
......@@ -25,7 +43,6 @@ target("ygopro_ygoenv")
print("Copy target to " .. install_target)
end)
target("edopro_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/edopro/*.cpp")
......
from typing import Tuple, Union, Optional
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding
default_embed_init = nn.initializers.uniform(scale=0.001)
......@@ -14,11 +16,18 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
if noam:
return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
else:
return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
......@@ -26,7 +35,6 @@ class ActionEncoder(nn.Module):
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
......@@ -38,18 +46,165 @@ class ActionEncoder(nn.Module):
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
return jnp.concatenate([
x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib], axis=-1)
xs = [x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib]
return xs
class ActionEncoderV1(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
div = 8
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(10, c // div)(x[:, :, 1])
x_a_finish = embed(3, c // div // 2)(x[:, :, 2])
x_a_effect = embed(256, c // div * 2)(x[:, :, 3])
x_a_phase = embed(4, c // div // 2)(x[:, :, 4])
x_a_position = embed(9, c // div)(x[:, :, 5])
x_a_number = embed(13, c // div // 2)(x[:, :, 6])
x_a_place = embed(31, c // div)(x[:, :, 7])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 8])
xs = [x_a_msg, x_a_act, x_a_finish, x_a_effect, x_a_phase,
x_a_position, x_a_number, x_a_place, x_a_attrib]
return xs
class CardEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
version: int = 0
@nn.compact
def __call__(self, x_id, x):
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :, :10].astype(jnp.int32)
x2 = x[:, :, 10:].astype(jnp.float32)
x_loc = x1[:, :, 0]
x_seq = x1[:, :, 1]
if self.version == 0:
x_id = mlp(
(c, c // 4), kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
f_loc = layer_norm()(embed(9, c)(x_loc))
f_seq = layer_norm()(embed(76, c)(x_seq))
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
x_owner = embed(2, c // 16)(x1[:, :, 2])
x_position = embed(9, c // 16)(x1[:, :, 3])
x_overley = embed(2, c // 16)(x1[:, :, 4])
x_attribute = embed(8, c // 16)(x1[:, :, 5])
x_race = embed(27, c // 16)(x1[:, :, 6])
x_level = embed(14, c // 16)(x1[:, :, 7])
x_counter = embed(16, c // 16)(x1[:, :, 8])
x_negated = embed(3, c // 16)(x1[:, :, 9])
x_atk = num_transform(x2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
if self.version == 0:
x_f = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
else:
x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
x_id = jax.nn.swish(x_id)
f_loc = embed(9, c // 16 * 2)(x_loc)
f_seq = embed(76, c // 16 * 2)(x_seq)
x_cards = jnp.concatenate([
f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * x_id
f_cards = layer_norm()(x_cards)
return f_cards, c_mask
class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
batch_size = x.shape[0]
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :4].astype(jnp.float32)
x2 = x[:, 4:8].astype(jnp.int32)
x3 = x[:, 8:22].astype(jnp.int32)
x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 0:2]))
x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 2:4]))
x_turn = embed(20, c // 8)(x2[:, 0])
x_phase = embed(11, c // 8)(x2[:, 1])
x_if_first = embed(2, c // 8)(x2[:, 2])
x_is_my_turn = embed(2, c // 8)(x2[:, 3])
x_cs = count_embed(x3).reshape((batch_size, -1))
x_my_hand_c = hand_count_embed(x3[:, 1])
x_op_hand_c = hand_count_embed(x3[:, 8])
x = jnp.concatenate([
x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
x = layer_norm()(x)
return x
class Encoder(nn.Module):
channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
noam: bool = False
version: int = 0
@nn.compact
def __call__(self, x):
......@@ -62,154 +217,182 @@ class Encoder(nn.Module):
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
fc_layer = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = MLP((c // 8,), last_lin=False, dtype=jnp.float32, param_dtype=self.param_dtype)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
action_encoder = ActionEncoderCls(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_cards_1 = x_cards[:, :, :12].astype(jnp.int32)
x_cards_2 = x_cards[:, :, 12:].astype(jnp.float32)
x_id = decode_id(x_cards_1[:, :, :2])
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id)
x_id = MLP(
(c, c // 4), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
f_loc = layer_norm()(embed(9, c)(x_loc))
if self.freeze_id:
x_id = jax.lax.stop_gradient(x_id)
x_seq = x_cards_1[:, :, 3]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x_cards_1[:, :, 4])
x_position = embed(9, c // 16)(x_cards_1[:, :, 5])
x_overley = embed(2, c // 16)(x_cards_1[:, :, 6])
x_attribute = embed(8, c // 16)(x_cards_1[:, :, 7])
x_race = embed(27, c // 16)(x_cards_1[:, :, 8])
x_level = embed(14, c // 16)(x_cards_1[:, :, 9])
x_counter = embed(16, c // 16)(x_cards_1[:, :, 10])
x_negated = embed(3, c // 16)(x_cards_1[:, :, 11])
x_atk = num_transform(x_cards_2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x_cards_2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:])
# Cards
f_cards, c_mask = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_id, x_cards[:, :, 2:])
g_card_embed = self.param(
'g_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
if self.card_mask:
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
else:
c_mask = None
x_feat = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_feat = layer_norm()(x_feat)
f_cards = jnp.concatenate([x_id, x_feat], axis=-1)
f_cards = f_cards + f_loc + f_seq
num_heads = max(2, c // 128)
for _ in range(self.num_card_layers):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
for _ in range(self.num_layers):
f_cards = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_cards, src_key_padding_mask=c_mask)
f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0]
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
x_global = x_global.astype(self.dtype)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
if self.version == 0:
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_h_id)
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
else:
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., 1:3])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = fc_layer(c, dtype=jnp.float32)(x_h_id)
x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
x_h_a_phase = embed(12, c // 2)(x_h_actions[:, :, 13])
x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
x_h_a_feats = layer_norm()(x_h_a_feats)
x_h_a_feats = fc_layer(c, dtype=self.dtype)(x_h_a_feats)
if self.noam:
f_h_actions = LlamaEncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype,
rope=True, n_positions=64)(x_h_a_feats, src_key_padding_mask=h_mask)
else:
x_h_a_feats = PositionalEncoding()(x_h_a_feats)
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
x_h_a_feats, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
# Actions
x_actions = x_actions.astype(jnp.int32)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards], axis=1)
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
f_cards = layer_norm()(f_cards)
x_global_1 = x_global[:, :4].astype(jnp.float32)
x_g_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_2 = x_global[:, 4:8].astype(jnp.int32)
x_g_turn = embed(20, c // 8)(x_global_2[:, 0])
x_g_phase = embed(11, c // 8)(x_global_2[:, 1])
x_g_if_first = embed(2, c // 8)(x_global_2[:, 2])
x_g_is_my_turn = embed(2, c // 8)(x_global_2[:, 3])
x_global_3 = x_global[:, 8:22].astype(jnp.int32)
x_g_cs = count_embed(x_global_3).reshape((batch_size, -1))
x_g_my_hand_c = hand_count_embed(x_global_3[:, 1])
x_g_op_hand_c = hand_count_embed(x_global_3[:, 8])
x_global = jnp.concatenate([
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1)
x_global = layer_norm()(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c)(f_global)
f_global = layer_norm()(f_global)
f_cards = f_cards + jnp.expand_dims(f_global, 1)
x_actions = x_actions.astype(jnp.int32)
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = f_a_cards + fc_layer(c)(layer_norm()(f_a_cards))
x_a_feats = action_encoder(x_actions[..., 2:])
f_actions = f_a_cards + layer_norm()(x_a_feats)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, f_cards,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_'].astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
f_h_actions = layer_norm()(x_h_id) + layer_norm()(x_h_a_feats)
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_action_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(
f_actions, f_h_actions,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=h_mask)
f_actions = layer_norm()(f_actions)
f_s_cards_global = f_cards.mean(axis=1)
c_mask = 1 - a_mask[:, :, None].astype(f_actions.dtype)
f_s_actions_ha = (f_actions * c_mask).sum(axis=1) / c_mask.sum(axis=1)
f_state = jnp.concatenate([f_s_cards_global, f_s_actions_ha], axis=-1)
f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
if self.version == 0:
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
if not self.use_history:
f_g_h_actions = jnp.zeros_like(f_g_h_actions)
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
spec_index = x_actions[..., 0]
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
x_a_id = decode_id(x_actions[..., 1:3])
x_a_id = id_embed(x_a_id)
if self.freeze_id:
x_a_id = jax.lax.stop_gradient(x_a_id)
x_a_id = fc_layer(c, dtype=jnp.float32)(x_a_id)
x_a_feats = action_encoder(x_actions[..., 3:])
x_a_feats.append(x_a_id)
x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
x_a_feats = layer_norm()(x_a_feats)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
f_actions = jax.nn.silu(f_a_cards) * x_a_feats
f_actions = fc_layer(c, dtype=self.dtype)(f_actions)
f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False)
f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
if self.use_history:
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
else:
f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid
......@@ -219,54 +402,199 @@ class Actor(nn.Module):
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_actions, mask):
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
f_state = mlp((c,), use_bias=True)(f_state)
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class FiLMActor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
noam: bool = False
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
num_heads = max(2, c // 128)
f_actions = EncoderLayer(
num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask)
logits = mlp((c // 4, 1), use_bias=True)(f_actions)
logits = logits[..., 0]
f_actions = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, mask, a_s, a_b, o_s, o_b)
logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class Critic(nn.Module):
channels: int = 128
channels: Sequence[int] = (128, 128, 128)
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(1.0))
x = MLP((c // 2, 1), use_bias=True)(f_state)
f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
return x
class PPOAgent(nn.Module):
channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
if main is not None:
rstate1, rstate2 = rstate
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, f_state = rnn_layer(rstate, f_state)
if main is not None:
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate = rstate1, rstate2
if done is not None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
if switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y
else:
def body_fn(cell, carry, x, done, main):
return rnn_step_by_main(cell, carry, x, done, main)
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
rstate, f_state = scan(rnn_layer, rstate, f_state, done, switch_or_main)
return rstate, f_state
@dataclass
class ModelArgs:
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
use_history: bool = True
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
noam: bool = False
"""whether to use Noam architecture for the transformer layer"""
version: int = 0
"""the version of the environment and the agent"""
class RNNAgent(nn.Module):
num_layers: int = 2
num_channels: int = 128
rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
switch: bool = True
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
version: int = 0
@nn.compact
def __call__(self, x):
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.num_channels
encoder = Encoder(
channels=self.channels,
num_card_layers=self.num_card_layers,
num_action_layers=self.num_action_layers,
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
use_history=self.use_history,
card_mask=self.card_mask,
noam=self.noam,
version=self.version,
)
actor = Actor(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
critic = Critic(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
f_actions, f_state, mask, valid = encoder(x)
logits = actor(f_actions, mask)
value = critic(f_state)
return logits, value, valid
if self.rnn_type in ['lstm', 'none']:
rnn_layer = nn.OptimizedLSTMCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type == 'gru':
rnn_layer = nn.GRUCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type is None:
rnn_layer = None
if rnn_layer is None:
f_state_r = f_state
elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else:
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1
if done is not None:
assert switch_or_main is not None
else:
assert not multi_step
if multi_step:
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
else:
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r)
return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']:
return (
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels)),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
else:
return None
\ No newline at end of file
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
if noam:
return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
else:
return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
div = 8
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
x_a_phase = embed(4, c // div)(x[:, :, 3])
x_a_cancel = embed(3, c // div)(x[:, :, 4])
x_a_finish = embed(3, c // div // 2)(x[:, :, 5])
x_a_position = embed(9, c // div // 2)(x[:, :, 6])
x_a_option = embed(6, c // div // 2)(x[:, :, 7])
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
xs = [x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib]
return xs
class CardEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x_id, x):
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :, :10].astype(jnp.int32)
x2 = x[:, :, 10:].astype(jnp.float32)
x_id = mlp(
(c, c // 4), kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
x_loc = x1[:, :, 0]
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x1[:, :, 1]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x1[:, :, 2])
x_position = embed(9, c // 16)(x1[:, :, 3])
x_overley = embed(2, c // 16)(x1[:, :, 4])
x_attribute = embed(8, c // 16)(x1[:, :, 5])
x_race = embed(27, c // 16)(x1[:, :, 6])
x_level = embed(14, c // 16)(x1[:, :, 7])
x_counter = embed(16, c // 16)(x1[:, :, 8])
x_negated = embed(3, c // 16)(x1[:, :, 9])
x_atk = num_transform(x2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
x_f = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
return f_cards, c_mask
class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
batch_size = x.shape[0]
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :4].astype(jnp.float32)
x2 = x[:, 4:8].astype(jnp.int32)
x3 = x[:, 8:22].astype(jnp.int32)
x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 0:2]))
x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 2:4]))
x_turn = embed(20, c // 8)(x2[:, 0])
x_phase = embed(11, c // 8)(x2[:, 1])
x_if_first = embed(2, c // 8)(x2[:, 2])
x_is_my_turn = embed(2, c // 8)(x2[:, 3])
x_cs = count_embed(x3).reshape((batch_size, -1))
x_my_hand_c = hand_count_embed(x3[:, 1])
x_op_hand_c = hand_count_embed(x3[:, 8])
x = jnp.concatenate([
x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
x = layer_norm()(x)
return x
class Encoder(nn.Module):
channels: int = 128
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
noam: bool = False
@nn.compact
def __call__(self, x):
c = self.channels
if self.embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
action_encoder = ActionEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id)
if self.freeze_id:
x_id = jax.lax.stop_gradient(x_id)
# Cards
f_cards, c_mask = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
g_card_embed = self.param(
'g_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
if self.card_mask:
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
else:
c_mask = None
num_heads = max(2, c // 128)
for _ in range(self.num_layers):
f_cards = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_cards, src_key_padding_mask=c_mask)
f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0]
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
x_global = x_global.astype(self.dtype)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = id_embed(x_h_id)
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_h_id)
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
# Actions
x_actions = x_actions.astype(jnp.int32)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
# State
if not self.use_history:
f_g_h_actions = jnp.zeros_like(f_g_h_actions)
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
f_state = mlp((c,), use_bias=True)(f_state)
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class FiLMActor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
noam: bool = False
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
num_heads = max(2, c // 128)
f_actions = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, mask, a_s, a_b, o_s, o_b)
logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class Critic(nn.Module):
channels: Sequence[int] = (128, 128, 128)
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
return x
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
if main is not None:
rstate1, rstate2 = rstate
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, f_state = rnn_layer(rstate, f_state)
if main is not None:
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate = rstate1, rstate2
if done is not None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True):
if switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y
else:
def body_fn(cell, carry, x, done, main):
return rnn_step_by_main(cell, carry, x, done, main)
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
rstate, f_state = scan(rnn_layer, rstate, f_state, done, switch_or_main)
return rstate, f_state
@dataclass
class ModelArgs:
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
use_history: bool = True
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
noam: bool = False
"""whether to use Noam architecture for the transformer layer"""
class RNNAgent(nn.Module):
num_layers: int = 2
num_channels: int = 128
rnn_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
switch: bool = True
freeze_id: bool = False
use_history: bool = True
card_mask: bool = False
rnn_type: str = 'lstm'
film: bool = False
noam: bool = False
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
c = self.num_channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
use_history=self.use_history,
card_mask=self.card_mask,
noam=self.noam,
)
f_actions, f_state, mask, valid = encoder(x)
if self.rnn_type in ['lstm', 'none']:
rnn_layer = nn.OptimizedLSTMCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type == 'gru':
rnn_layer = nn.GRUCell(
self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
elif self.rnn_type is None:
rnn_layer = None
if rnn_layer is None:
f_state_r = f_state
elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else:
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1
if done is not None:
assert switch_or_main is not None
else:
assert not multi_step
if multi_step:
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else:
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
else:
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r)
return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
if self.rnn_type in ['lstm', 'none']:
return (
np.zeros((batch_size, self.rnn_channels)),
np.zeros((batch_size, self.rnn_channels)),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.rnn_channels))
else:
return None
\ No newline at end of file
......@@ -41,7 +41,12 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
raise FileNotFoundError(f"Token deck not found: {token_deck}")
decks["_tokens"] = str(token_deck)
if 'YGOPro' in env_id:
from ygoenv.ygopro import init_module
if env_id == 'YGOPro-v1':
from ygoenv.ygopro import init_module
elif env_id == 'YGOPro-v0':
from ygoenv.ygopro0 import init_module
else:
raise ValueError(f"Unknown YGOPro environment: {env_id}")
elif 'EDOPro' in env_id:
from ygoenv.edopro import init_module
init_module(str(db_path), code_list_file, decks)
......
......@@ -18,13 +18,16 @@ try:
except ImportError:
pass
try:
import ygoenv.ygopro0.registration # noqa: F401
except ImportError:
pass
try:
import ygoenv.edopro.registration # noqa: F401
except ImportError:
pass
try:
import ygoenv.dummy.registration # noqa: F401
except ImportError:
......
from ygoenv.registration import register
register(
task_id="YGOPro-v0",
task_id="YGOPro-v1",
import_path="ygoenv.ygopro",
spec_cls="YGOProEnvSpec",
dm_cls="YGOProDMEnvPool",
......
......@@ -23,7 +23,7 @@
#include <ankerl/unordered_dense.h>
#include <unordered_set>
#include "BS_thread_pool.h"
#include "ygoenv/core/BS_thread_pool.h"
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h"
......@@ -305,13 +305,85 @@ static std::string msg_to_string(int msg) {
}
// system string
static const ankerl::unordered_dense::map<int, std::string> system_strings = {
static const std::map<int, std::string> system_strings = {
// announce type
{1050, "Monster"},
{1051, "Spell"},
{1052, "Trap"},
{1054, "Normal"},
{1055, "Effect"},
{1056, "Fusion"},
{1057, "Ritual"},
{1058, "Trap Monsters"},
{1059, "Spirit"},
{1060, "Union"},
{1061, "Gemini"},
{1062, "Tuner"},
{1063, "Synchro"},
{1064, "Token"},
{1066, "Quick-Play"},
{1067, "Continuous"},
{1068, "Equip"},
{1069, "Field"},
{1070, "Counter"},
{1071, "Flip"},
{1072, "Toon"},
{1073, "Xyz"},
{1074, "Pendulum"},
{1075, "Special Summon"},
{1076, "Link"},
{1080, "(N/A)"},
{1081, "Extra Monster Zone"},
// announce type end
// actions
{1150, "Activate"},
{1151, "Normal Summon"},
{1152, "Special Summon"},
{1153, "Set"},
{1154, "Flip Summon"},
{1155, "To Defense"},
{1156, "To Attack"},
{1157, "Attack"},
{1158, "View"},
{1159, "S/T Set"},
{1160, "Put in Pendulum Zone"},
{1161, "Do Effect"},
{1162, "Reset Effect"},
{1163, "Pendulum Summon"},
{1164, "Synchro Summon"},
{1165, "Xyz Summon"},
{1166, "Link Summon"},
{1167, "Tribute Summon"},
{1168, "Ritual Summon"},
{1169, "Fusion Summon"},
{1190, "Add to hand"},
{1191, "Send to GY"},
{1192, "Banish"},
{1193, "Return to Deck"},
// actions end
{1, "Normal Summon"},
{30, "Replay rules apply. Continue this attack?"},
{31, "Attack directly with this monster?"},
{80, "Start Step of the Battle Phase."},
{81, "During the End Phase."},
{90, "Conduct this Normal Summon without Tributing?"},
{91, "Use additional Summon?"},
{92, "Tribute your opponent's monster?"},
{93, "Continue selecting Materials?"},
{94, "Activate this card's effect now?"},
{95, "Use the effect of [%ls]?"},
{96, "Use the effect of [%ls] to avoid destruction?"},
{97, "Place [%ls] to a Spell & Trap Zone?"},
{98, "Tribute a monster(s) your opponent controls?"},
{200, "From [%ls], activate [%ls]?"},
{203, "Chain another card or effect?"},
{210, "Continue selecting?"},
{218, "Pay LP by Effect of [%ls], instead?"},
{219, "Detach Xyz material by Effect of [%ls], instead?"},
{220, "Remove Counter(s) by Effect of [%ls], instead?"},
{221, "On [%ls], Activate Trigger Effect of [%ls]?"},
{222, "Activate Trigger Effect?"},
{221, "On [%ls], Activate Trigger Effect of [%ls]?"},
{1190, "Add to hand"},
{1192, "Banish"},
{1621, "Attack Negated"},
{1622, "[%ls] Missed timing"}
};
......@@ -321,7 +393,9 @@ static std::string get_system_string(int desc) {
if (it != system_strings.end()) {
return it->second;
}
return "system string " + std::to_string(desc);
throw std::runtime_error(
fmt::format("Cannot find system string: {}", desc));
// return "system string " + std::to_string(desc);
}
static std::string ltrim(std::string s) {
......@@ -331,24 +405,6 @@ static std::string ltrim(std::string s) {
return s;
}
inline std::vector<std::string> flag_to_usable_cardspecs(uint32_t flag,
bool reverse = false) {
std::string zone_names[4] = {"m", "s", "om", "os"};
std::vector<std::string> specs;
for (int j = 0; j < 4; j++) {
uint32_t value = (flag >> (j * 8)) & 0xff;
for (int i = 0; i < 8; i++) {
bool avail = (value & (1 << i)) == 0;
if (reverse) {
avail = !avail;
}
if (avail) {
specs.push_back(zone_names[j] + std::to_string(i + 1));
}
}
}
return specs;
}
inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos) {
std::string spec;
......@@ -402,7 +458,8 @@ spec_to_ls(const std::string spec) {
loc = LOCATION_DECK;
offset = 0;
} else {
throw std::runtime_error("Invalid location");
std::string s = fmt::format("Invalid spec {}", spec);
throw std::runtime_error(s);
}
int end = offset;
while (end < spec.size() && std::isdigit(spec[end])) {
......@@ -415,33 +472,19 @@ spec_to_ls(const std::string spec) {
return {loc, seq, pos};
}
inline uint32_t ls_to_spec_code(uint8_t loc, uint8_t seq, uint8_t pos,
bool opponent) {
uint32_t c = opponent ? 1 : 0;
c |= (loc << 8);
c |= (seq << 16);
c |= (pos << 24);
return c;
}
inline uint32_t spec_to_code(const std::string &spec) {
inline std::tuple<uint8_t, uint8_t, uint8_t, uint8_t>
spec_to_ls(uint8_t player, const std::string spec) {
uint8_t controller = player;
int offset = 0;
bool opponent = false;
if (spec[0] == 'o') {
opponent = true;
controller = 1 - player;
offset++;
}
auto [loc, seq, pos] = spec_to_ls(spec.substr(offset));
return ls_to_spec_code(loc, seq, pos, opponent);
return {controller, loc, seq, pos};
}
inline std::string code_to_spec(uint32_t spec_code) {
uint8_t loc = (spec_code >> 8) & 0xff;
uint8_t seq = (spec_code >> 16) & 0xff;
uint8_t pos = (spec_code >> 24) & 0xff;
bool opponent = (spec_code & 0xff) == 1;
return ls_to_spec(loc, seq, pos, opponent);
}
static std::tuple<std::vector<uint32>, std::vector<uint32>, std::vector<uint32>> read_decks(const std::string &fp) {
std::ifstream file(fp);
......@@ -567,6 +610,11 @@ inline std::string name(decltype(x_map)::key_type x) { \
return "unknown"; \
}
static const ankerl::unordered_dense::map<int, uint8_t> system_string2id =
make_ids(system_strings, 16);
DEFINE_X_TO_ID_FUN(system_string_to_id, system_string2id)
static const std::map<uint8_t, std::string> location2str = {
{LOCATION_DECK, "Deck"},
{LOCATION_HAND, "Hand"},
......@@ -722,29 +770,152 @@ static const ankerl::unordered_dense::map<int, uint8_t> msg2id =
DEFINE_X_TO_ID_FUN(msg_to_id, msg2id)
static const ankerl::unordered_dense::map<char, uint8_t> cmd_act2id =
make_ids({'t', 'r', 'c', 's', 'm', 'a', 'v'}, 1);
DEFINE_X_TO_ID_FUN(cmd_act_to_id, cmd_act2id)
enum class ActionAct {
None,
Set,
Repo,
SpSummon,
Summon,
MSet,
Attack,
DirectAttack,
Activate,
Cancel,
};
inline std::string action_act_to_string(ActionAct act) {
switch (act) {
case ActionAct::None:
return "None";
case ActionAct::Set:
return "Set";
case ActionAct::Repo:
return "Repo";
case ActionAct::SpSummon:
return "SpSummon";
case ActionAct::Summon:
return "Summon";
case ActionAct::MSet:
return "MSet";
case ActionAct::Attack:
return "Attack";
case ActionAct::DirectAttack:
return "DirectAttack";
case ActionAct::Activate:
return "Activate";
case ActionAct::Cancel:
return "Cancel";
default:
return "Unknown";
}
}
static const ankerl::unordered_dense::map<char, uint8_t> cmd_phase2id =
make_ids(std::vector<char>({'b', 'm', 'e'}), 1);
DEFINE_X_TO_ID_FUN(cmd_phase_to_id, cmd_phase2id)
enum class ActionPhase {
None,
Battle,
Main2,
End,
};
inline std::string action_phase_to_string(ActionPhase phase) {
switch (phase) {
case ActionPhase::None:
return "None";
case ActionPhase::Battle:
return "Battle";
case ActionPhase::Main2:
return "Main2";
case ActionPhase::End:
return "End";
default:
return "Unknown";
}
}
static const ankerl::unordered_dense::map<char, uint8_t> cmd_yesno2id =
make_ids(std::vector<char>({'y', 'n'}), 1);
DEFINE_X_TO_ID_FUN(cmd_yesno_to_id, cmd_yesno2id)
enum class ActionPlace {
None,
MZone1,
MZone2,
MZone3,
MZone4,
MZone5,
MZone6,
MZone7,
SZone1,
SZone2,
SZone3,
SZone4,
SZone5,
SZone6,
SZone7,
SZone8,
OpMZone1,
OpMZone2,
OpMZone3,
OpMZone4,
OpMZone5,
OpMZone6,
OpMZone7,
OpSZone1,
OpSZone2,
OpSZone3,
OpSZone4,
OpSZone5,
OpSZone6,
OpSZone7,
OpSZone8,
};
static const ankerl::unordered_dense::map<std::string, uint8_t> cmd_place2id =
make_ids(std::vector<std::string>(
{"m1", "m2", "m3", "m4", "m5", "m6", "m7", "s1",
"s2", "s3", "s4", "s5", "s6", "s7", "s8", "om1",
"om2", "om3", "om4", "om5", "om6", "om7", "os1", "os2",
"os3", "os4", "os5", "os6", "os7", "os8"}),
1);
DEFINE_X_TO_ID_FUN(cmd_place_to_id, cmd_place2id)
inline std::vector<ActionPlace> flag_to_usable_places(
uint32_t flag, bool reverse = false) {
std::vector<ActionPlace> places;
for (int j = 0; j < 4; j++) {
uint32_t value = (flag >> (j * 8)) & 0xff;
for (int i = 0; i < 8; i++) {
bool avail = (value & (1 << i)) == 0;
if (reverse) {
avail = !avail;
}
if (avail) {
ActionPlace place;
if (j == 0) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::MZone1));
} else if (j == 1) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::SZone1));
} else if (j == 2) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::OpMZone1));
} else if (j == 3) {
place = static_cast<ActionPlace>(i + static_cast<int>(ActionPlace::OpSZone1));
}
places.push_back(place);
}
}
}
return places;
}
inline std::string action_place_to_string(ActionPlace place) {
int i = static_cast<int>(place);
if (i == 0) {
return "None";
}
else if (i >= static_cast<int>(ActionPlace::MZone1) && i <= static_cast<int>(ActionPlace::MZone7)) {
return fmt::format("m{}", i - static_cast<int>(ActionPlace::MZone1) + 1);
}
else if (i >= static_cast<int>(ActionPlace::SZone1) && i <= static_cast<int>(ActionPlace::SZone8)) {
return fmt::format("s{}", i - static_cast<int>(ActionPlace::SZone1) + 1);
}
else if (i >= static_cast<int>(ActionPlace::OpMZone1) && i <= static_cast<int>(ActionPlace::OpMZone7)) {
return fmt::format("om{}", i - static_cast<int>(ActionPlace::OpMZone1) + 1);
}
else if (i >= static_cast<int>(ActionPlace::OpSZone1) && i <= static_cast<int>(ActionPlace::OpSZone8)) {
return fmt::format("os{}", i - static_cast<int>(ActionPlace::OpSZone1) + 1);
}
else {
return "Unknown";
}
}
inline std::pair<uint8_t, uint8_t> float_transform(int x) {
......@@ -807,6 +978,89 @@ using PlayerId = uint8_t;
using CardCode = uint32_t;
using CardId = uint16_t;
const int DESCRIPTION_LIMIT = 10000;
const int CARD_EFFECT_OFFSET = 10010;
class LegalAction {
public:
std::string spec_ = "";
ActionAct act_ = ActionAct::None;
ActionPhase phase_ = ActionPhase::None;
bool finish_ = false;
uint8_t position_ = 0;
int effect_ = -1;
uint8_t number_ = 0;
ActionPlace place_ = ActionPlace::None;
uint8_t attribute_ = 0;
int spec_index_ = 0;
CardId cid_ = 0;
int msg_ = 0;
static LegalAction from_spec(const std::string &spec) {
LegalAction la;
la.spec_ = spec;
return la;
}
static LegalAction act_spec(ActionAct act, const std::string &spec) {
LegalAction la;
la.act_ = act;
la.spec_ = spec;
return la;
}
static LegalAction finish() {
LegalAction la;
la.finish_ = true;
return la;
}
static LegalAction cancel() {
LegalAction la;
la.act_ = ActionAct::Cancel;
return la;
}
static LegalAction activate_spec(int effect_idx, const std::string &spec) {
LegalAction la;
la.act_ = ActionAct::Activate;
la.effect_ = effect_idx;
la.spec_ = spec;
return la;
}
static LegalAction phase(ActionPhase phase) {
LegalAction la;
la.phase_ = phase;
return la;
}
static LegalAction number(uint8_t number) {
LegalAction la;
la.number_ = number;
return la;
}
static LegalAction place(ActionPlace place) {
LegalAction la;
la.place_ = place;
return la;
}
static LegalAction attribute(int attribute) {
LegalAction la;
la.attribute_ = attribute;
return la;
}
};
class SpecInfo {
public:
uint16_t index;
CardId cid;
};
class Card {
friend class YGOProEnv;
......@@ -874,42 +1128,23 @@ public:
return get_spec(player != controler_);
}
uint32_t get_spec_code(PlayerId player) const {
return ls_to_spec_code(location_, sequence_, position_,
player != controler_);
}
std::string get_position() const { return position_to_string(position_); }
std::string get_effect_description(uint32_t desc,
bool existing = false) const {
std::string s;
bool e = false;
auto code = code_;
if (desc > 10000) {
code = desc >> 4;
std::string get_effect_description(CardCode code, int effect_idx) const {
if (code == 0) {
return get_system_string(effect_idx);
}
uint32_t offset = desc - code_ * 16;
bool in_range = (offset >= 0) && (offset < strings_.size());
std::string str = "";
if (in_range) {
str = ltrim(strings_[offset]);
if (effect_idx == 0) {
return "default";
}
if (in_range || desc == 0) {
if ((desc == 0) || str.empty()) {
s = "Activate " + name_ + ".";
} else {
s = name_ + " (" + str + ")";
e = true;
}
} else {
s = get_system_string(desc);
if (!s.empty()) {
e = true;
}
effect_idx -= CARD_EFFECT_OFFSET;
if (effect_idx < 0) {
throw std::runtime_error(
fmt::format("Invalid effect index: {}", effect_idx));
}
if (existing && !e) {
s = "";
auto s = strings_[effect_idx];
if (s.empty()) {
return "effect " + std::to_string(effect_idx);
}
return s;
}
......@@ -1222,7 +1457,7 @@ public:
const int &init_lp() const { return init_lp_; }
virtual int think(const std::vector<std::string> &options) = 0;
virtual int think(const std::vector<LegalAction> &actions) = 0;
};
class GreedyAI : public Player {
......@@ -1232,7 +1467,7 @@ public:
bool verbose = false)
: Player(nickname, init_lp, duel_player, verbose) {}
int think(const std::vector<std::string> &options) override { return 0; }
int think(const std::vector<LegalAction> &actions) override { return 0; }
};
class RandomAI : public Player {
......@@ -1246,8 +1481,8 @@ public:
: Player(nickname, init_lp, duel_player, verbose), gen_(seed),
dist_(0, max_options - 1) {}
int think(const std::vector<std::string> &options) override {
return dist_(gen_) % options.size();
int think(const std::vector<LegalAction> &actions) override {
return dist_(gen_) % actions.size();
}
};
......@@ -1258,17 +1493,17 @@ public:
bool verbose = false)
: Player(nickname, init_lp, duel_player, verbose) {}
int think(const std::vector<std::string> &options) override {
int think(const std::vector<LegalAction> &actions) override {
while (true) {
std::string input = getline();
if (input == "quit") {
exit(0);
}
auto it = std::find(options.begin(), options.end(), input);
if (it != options.end()) {
return std::distance(options.begin(), it);
int idx = std::stoi(input) - 1;
if (idx >= 0 && idx < actions.size()) {
return idx;
} else {
fmt::println("{} Choose from {}", duel_player_, options);
fmt::println("{} Choose from {} actions", duel_player_, actions.size());
}
}
}
......@@ -1286,7 +1521,7 @@ public:
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
int n_action_feats = 13;
int n_action_feats = 12;
return MakeDict(
"obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"obs:global_"_.Bind(Spec<uint8_t>({23})),
......@@ -1393,7 +1628,7 @@ protected:
int turn_count_;
int msg_;
std::vector<std::string> options_;
std::vector<LegalAction> legal_actions_;
PlayerId to_play_;
std::function<void(int)> callback_;
......@@ -1423,9 +1658,10 @@ protected:
const int n_history_actions_;
// circular buffer for history actions
TArray<uint8_t> history_actions_;
int ha_p_ = 0;
std::vector<CardId> h_card_ids_;
TArray<uint8_t> history_actions_1_;
TArray<uint8_t> history_actions_2_;
int ha_p_1_ = 0;
int ha_p_2_ = 0;
std::unordered_set<std::string> revealed_;
......@@ -1487,8 +1723,9 @@ public:
int max_options = spec.config["max_options"_];
int n_action_feats = spec.state_spec["obs:actions_"_].shape[1];
h_card_ids_.resize(max_options);
history_actions_ = TArray<uint8_t>(Array(
history_actions_1_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
history_actions_2_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
}
......@@ -1560,8 +1797,10 @@ public:
turn_count_ = 0;
ms_idx_ = -1;
history_actions_.Zero();
ha_p_ = 0;
history_actions_1_.Zero();
history_actions_2_.Zero();
ha_p_1_ = 0;
ha_p_2_ = 0;
clock_t _start = clock();
......@@ -1720,7 +1959,7 @@ public:
if (ms_mode_ == 0) {
for (int j = 0; j < ms_specs_.size(); ++j) {
const auto &spec = ms_specs_[j];
options_.push_back(spec);
legal_actions_.push_back(LegalAction::from_spec(spec));
}
} else {
ms_combs_ = combs;
......@@ -1729,22 +1968,23 @@ public:
}
void handle_multi_select() {
options_ = {};
legal_actions_.clear();
if (ms_mode_ == 0) {
for (int j = 0; j < ms_specs_.size(); ++j) {
if (ms_spec2idx_.find(ms_specs_[j]) != ms_spec2idx_.end()) {
options_.push_back(ms_specs_[j]);
legal_actions_.push_back(
LegalAction::from_spec(ms_specs_[j]));
}
}
if (ms_idx_ == ms_max_ - 1) {
if (ms_idx_ >= ms_min_) {
options_.push_back("f");
legal_actions_.push_back(LegalAction::finish());
}
callback_ = [this](int idx) {
_callback_multi_select(idx, true);
};
} else if (ms_idx_ >= ms_min_) {
options_.push_back("f");
legal_actions_.push_back(LegalAction::finish());
callback_ = [this](int idx) {
_callback_multi_select(idx, false);
};
......@@ -1766,7 +2006,7 @@ public:
if (it != ms_spec2idx_.end()) {
return it->second;
}
// TODO: find the root cause
// TODO(2): find the root cause
// print ms_spec2idx
show_deck(0);
show_deck(1);
......@@ -1783,11 +2023,15 @@ public:
}
void _callback_multi_select_2(int idx) {
const auto &option = options_[idx];
idx = get_ms_spec_idx(option);
const auto &action = legal_actions_[idx];
idx = get_ms_spec_idx(action.spec_);
if (idx == -1) {
// TODO: find the root cause
fmt::println("options: {}, idx: {}, option: {}", options_, idx, option);
// TODO(2): find the root cause
std::vector<std::string> specs;
for (const auto &la : legal_actions_) {
specs.push_back(la.spec_);
}
fmt::println("specs: {}, idx: {}, spec: {}", specs, idx, action.spec_);
throw std::runtime_error("Spec not found");
}
ms_r_idxs_.push_back(idx);
......@@ -1814,7 +2058,7 @@ public:
}
for (auto &i : comb) {
const auto &spec = ms_specs_[i];
options_.push_back(spec);
legal_actions_.push_back(LegalAction::from_spec(spec));
}
}
......@@ -1831,17 +2075,21 @@ public:
}
void _callback_multi_select(int idx, bool finish) {
const auto &option = options_[idx];
const auto &action = legal_actions_[idx];
// fmt::println("Select card: {}, finish: {}", option, finish);
if (option == "f") {
if (action.finish_) {
finish = true;
} else {
idx = get_ms_spec_idx(option);
idx = get_ms_spec_idx(action.spec_);
if (idx != -1) {
ms_r_idxs_.push_back(idx);
} else {
// TODO: find the root cause
fmt::println("options: {}, idx: {}, option: {}", options_, idx, option);
// TODO(2): find the root cause
std::vector<std::string> specs;
for (const auto &la : legal_actions_) {
specs.push_back(la.spec_);
}
fmt::println("specs: {}, idx: {}, spec: {}", specs, idx, action.spec_);
ms_idx_ = -1;
resp_buf_[0] = ms_min_;
for (int i = 0; i < ms_min_; ++i) {
......@@ -1860,27 +2108,27 @@ public:
YGO_SetResponseb(pduel_, resp_buf_);
} else {
ms_idx_++;
ms_spec2idx_.erase(option);
ms_spec2idx_.erase(action.spec_);
}
}
void update_h_card_ids(PlayerId player, int idx) {
h_card_ids_[idx] = parse_card_id(options_[idx], player);
}
void update_history_actions(PlayerId player, int idx) {
if ((msg_ == MSG_SELECT_CHAIN) & (options_[idx][0] == 'c')) {
void update_history_actions(PlayerId player, const LegalAction& action) {
if (action.act_ == ActionAct::Cancel) {
return;
}
ha_p_--;
if (ha_p_ < 0) {
ha_p_ = n_history_actions_ - 1;
auto& ha_p = player == 0 ? ha_p_1_ : ha_p_2_;
auto& history_actions = player == 0 ? history_actions_1_ : history_actions_2_;
ha_p--;
if (ha_p < 0) {
ha_p = n_history_actions_ - 1;
}
history_actions_[ha_p_].Zero();
_set_obs_action(history_actions_, ha_p_, msg_, options_[idx], {},
h_card_ids_[idx]);
history_actions_[ha_p_](13) = static_cast<uint8_t>(player);
history_actions_[ha_p_](14) = static_cast<uint8_t>(turn_count_);
history_actions[ha_p].Zero();
_set_obs_action(history_actions, ha_p, action);
// Spec index not available in history actions
history_actions[ha_p](0) = 0;
// history_actions[ha_p](12) = static_cast<uint8_t>(player);
history_actions[ha_p](12) = static_cast<uint8_t>(turn_count_);
history_actions[ha_p](13) = static_cast<uint8_t>(phase_to_id(current_phase_));
}
void show_deck(const std::vector<CardCode> &deck, const std::string &prefix) const {
......@@ -1910,18 +2158,18 @@ public:
}
void show_history_actions(PlayerId player) const {
const auto &ha = history_actions_;
const auto &ha = player == 0 ? history_actions_1_ : history_actions_2_;
// print card ids of history actions
for (int i = 0; i < n_history_actions_; ++i) {
fmt::print("history {}\n", i);
uint8_t msg_id = uint8_t(ha(i, 2));
uint8_t msg_id = uint8_t(ha(i, 3));
int msg = _msgs[msg_id - 1];
fmt::print("msg: {},", msg_to_string(msg));
uint8_t v1 = ha(i, 0);
uint8_t v2 = ha(i, 1);
uint8_t v1 = ha(i, 1);
uint8_t v2 = ha(i, 2);
CardId card_id = (static_cast<CardId>(v1) << 8) + static_cast<CardId>(v2);
fmt::print(" {};", card_id);
for (int j = 3; j < ha.Shape()[1]; j++) {
for (int j = 4; j < ha.Shape()[1]; j++) {
fmt::print(" {}", uint8_t(ha(i, j)));
}
fmt::print("\n");
......@@ -1933,7 +2181,7 @@ public:
int idx = action["action"_];
callback_(idx);
update_history_actions(to_play_, idx);
update_history_actions(to_play_, legal_actions_[idx]);
PlayerId player = to_play_;
......@@ -2012,10 +2260,10 @@ public:
}
private:
using SpecIndex = ankerl::unordered_dense::map<std::string, uint16_t>;
using SpecInfos = ankerl::unordered_dense::map<std::string, SpecInfo>;
std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
SpecIndex spec2index;
std::tuple<SpecInfos, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
SpecInfos spec_infos;
std::vector<int> loc_n_cards;
int offset = 0;
for (auto pi = 0; pi < 2; pi++) {
......@@ -2054,18 +2302,23 @@ private:
hide = false;
}
}
CardId card_id = 0;
if (!hide) {
card_id = c_get_card_id(c.code_);
}
_set_obs_card_(f_cards, offset, c, hide);
offset++;
spec2index[spec] = static_cast<uint16_t>(offset);
spec_infos[spec] = {static_cast<uint16_t>(offset), card_id};
}
}
}
}
return {spec2index, loc_n_cards};
return {spec_infos, loc_n_cards};
}
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide) {
bool hide, CardId card_id = 0) {
// check offset exceeds max_cards
uint8_t location = c.location_;
bool overlay = location & LOCATION_OVERLAY;
......@@ -2077,7 +2330,6 @@ private:
}
if (!hide) {
auto card_id = c_get_card_id(c.code_);
f_cards(offset, 0) = static_cast<uint8_t>(card_id >> 8);
f_cards(offset, 1) = static_cast<uint8_t>(card_id & 0xff);
}
......@@ -2148,17 +2400,10 @@ private:
}
}
void _set_obs_action_spec(TArray<uint8_t> &feat, int i,
const std::string &spec,
const SpecIndex &spec2index,
CardId card_id = 0) {
uint16_t idx;
if (spec2index.empty()) {
idx = card_id;
} else {
auto it = spec2index.find(spec);
if (it == spec2index.end()) {
// TODO: find the root cause
const SpecInfo& find_spec_info(SpecInfos &spec_infos, const std::string &spec) {
auto it = spec_infos.find(spec);
if (it == spec_infos.end()) {
// TODO(2): find the root cause
// print spec2index
show_deck(0);
show_deck(1);
......@@ -2166,135 +2411,111 @@ private:
show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec);
for (auto &[k, v] : spec2index) {
fmt::print("{}: {}, ", k, v);
for (auto &[k, v] : spec_infos) {
fmt::print("{}: {} {}, ", k, v.index, v.cid);
}
fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec);
idx = 1;
} else {
idx = it->second;
}
spec_infos[spec] = {1, 1};
return spec_infos[spec];
}
feat(i, 0) = static_cast<uint8_t>(idx >> 8);
feat(i, 1) = static_cast<uint8_t>(idx & 0xff);
}
void _set_obs_action_msg(TArray<uint8_t> &feat, int i, int msg) {
feat(i, 2) = msg_to_id(msg);
return it->second;
}
void _set_obs_action_act(TArray<uint8_t> &feat, int i, char act,
uint8_t act_offset = 0) {
feat(i, 3) = cmd_act_to_id(act) + act_offset;
void _set_obs_action_spec(
TArray<uint8_t> &feat, int i, int idx) {
feat(i, 0) = static_cast<uint8_t>(idx);
}
void _set_obs_action_yesno(TArray<uint8_t> &feat, int i, char yesno) {
feat(i, 4) = cmd_yesno_to_id(yesno);
void _set_obs_action_card_id(
TArray<uint8_t> &feat, int i, CardId cid) {
feat(i, 1) = static_cast<uint8_t>(cid >> 8);
feat(i, 2) = static_cast<uint8_t>(cid & 0xff);
}
void _set_obs_action_phase(TArray<uint8_t> &feat, int i, char phase) {
feat(i, 5) = cmd_phase_to_id(phase);
void _set_obs_action_msg(TArray<uint8_t> &feat, int i, int msg) {
feat(i, 3) = msg_to_id(msg);
}
void _set_obs_action_cancel(TArray<uint8_t> &feat, int i) {
feat(i, 6) = 1;
void _set_obs_action_act(TArray<uint8_t> &feat, int i, ActionAct act) {
feat(i, 4) = static_cast<uint8_t>(act);
}
void _set_obs_action_finish(TArray<uint8_t> &feat, int i) {
feat(i, 7) = 1;
feat(i, 5) = 1;
}
void _set_obs_action_effect(TArray<uint8_t> &feat, int i, int effect) {
// 0: None
// 1: default
// 2-15: card effect
// 16+: system
if (effect == -1) {
effect = 0;
} else if (effect == 0) {
effect = 1;
} else if (effect >= CARD_EFFECT_OFFSET) {
effect = effect - CARD_EFFECT_OFFSET + 2;
} else {
effect = system_string_to_id(effect);
}
feat(i, 6) = static_cast<uint8_t>(effect);
}
void _set_obs_action_position(TArray<uint8_t> &feat, int i, char position) {
position = 1 << (position - '1');
feat(i, 8) = position_to_id(position);
void _set_obs_action_phase(TArray<uint8_t> &feat, int i, ActionPhase phase){
feat(i, 7) = static_cast<uint8_t>(phase);
}
void _set_obs_action_option(TArray<uint8_t> &feat, int i, char option) {
feat(i, 9) = option - '0';
void _set_obs_action_position(TArray<uint8_t> &feat, int i, uint8_t position) {
feat(i, 8) = position_to_id(position);
}
void _set_obs_action_number(TArray<uint8_t> &feat, int i, char number) {
feat(i, 10) = number - '0';
void _set_obs_action_number(TArray<uint8_t> &feat, int i, uint8_t number) {
feat(i, 9) = number;
}
void _set_obs_action_place(TArray<uint8_t> &feat, int i, const std::string &spec) {
feat(i, 11) = cmd_place_to_id(spec);
void _set_obs_action_place(TArray<uint8_t> &feat, int i, ActionPlace place) {
feat(i, 10) = static_cast<uint8_t>(place);
}
void _set_obs_action_attrib(TArray<uint8_t> &feat, int i, uint8_t attrib) {
feat(i, 12) = attribute_to_id(attrib);
feat(i, 11) = attribute_to_id(attrib);
}
void _set_obs_action(TArray<uint8_t> &feat, int i, int msg,
const std::string &option, const SpecIndex &spec2index,
CardId card_id) {
void _set_obs_action(TArray<uint8_t> &feat, int i, const LegalAction &action) {
auto msg = action.msg_;
_set_obs_action_msg(feat, i, msg);
if (msg == MSG_SELECT_IDLECMD) {
if (option == "b" || option == "e") {
_set_obs_action_phase(feat, i, option[0]);
} else {
auto act = option[0];
auto spec = option.substr(2);
uint8_t offset = 0;
int n = spec.size();
if (act == 'v' && std::isalpha(spec[n - 1])) {
offset = spec[n - 1] - 'a';
spec = spec.substr(0, n - 1);
}
_set_obs_action_act(feat, i, act, offset);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_CHAIN) {
if (option[0] == 'c') {
_set_obs_action_cancel(feat, i);
} else {
char act = 'v';
auto spec = option;
uint8_t offset = 0;
auto n = spec.size();
if (std::isalpha(spec[n - 1])) {
offset = spec[n - 1] - 'a';
spec = spec.substr(0, n - 1);
}
_set_obs_action_act(feat, i, act, offset);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_CARD || msg == MSG_SELECT_TRIBUTE ||
msg == MSG_SELECT_SUM || msg == MSG_SELECT_UNSELECT_CARD) {
if (option[0] == 'f') {
_set_obs_action_card_id(feat, i, action.cid_);
if (msg == MSG_SELECT_CARD || msg == MSG_SELECT_TRIBUTE ||
msg == MSG_SELECT_SUM || msg == MSG_SELECT_UNSELECT_CARD) {
if (action.finish_) {
_set_obs_action_finish(feat, i);
} else {
_set_obs_action_spec(feat, i, option, spec2index, card_id);
_set_obs_action_spec(feat, i, action.spec_index_);
}
} else if (msg == MSG_SELECT_POSITION) {
_set_obs_action_position(feat, i, option[0]);
_set_obs_action_position(feat, i, action.position_);
} else if (msg == MSG_SELECT_EFFECTYN) {
auto spec = option.substr(2);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
_set_obs_action_yesno(feat, i, option[0]);
} else if (msg == MSG_SELECT_YESNO) {
_set_obs_action_yesno(feat, i, option[0]);
} else if (msg == MSG_SELECT_BATTLECMD) {
if (option == "m" || option == "e") {
_set_obs_action_phase(feat, i, option[0]);
} else {
auto act = option[0];
auto spec = option.substr(2);
_set_obs_action_act(feat, i, act);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_OPTION) {
_set_obs_action_option(feat, i, option[0]);
_set_obs_action_spec(feat, i, action.spec_index_);
_set_obs_action_act(feat, i, action.act_);
_set_obs_action_effect(feat, i, action.effect_);
} else if (msg == MSG_SELECT_YESNO || msg == MSG_SELECT_OPTION) {
_set_obs_action_act(feat, i, action.act_);
_set_obs_action_effect(feat, i, action.effect_);
} else if (
msg == MSG_SELECT_BATTLECMD ||
msg == MSG_SELECT_IDLECMD ||
msg == MSG_SELECT_CHAIN) {
_set_obs_action_phase(feat, i, action.phase_);
_set_obs_action_spec(feat, i, action.spec_index_);
_set_obs_action_act(feat, i, action.act_);
_set_obs_action_effect(feat, i, action.effect_);
} else if (msg == MSG_SELECT_PLACE || msg_ == MSG_SELECT_DISFIELD) {
_set_obs_action_place(feat, i, option);
_set_obs_action_place(feat, i, action.place_);
} else if (msg == MSG_ANNOUNCE_ATTRIB) {
_set_obs_action_attrib(feat, i, 1 << (option[0] - '1'));
_set_obs_action_attrib(feat, i, action.attribute_);
} else if (msg == MSG_ANNOUNCE_NUMBER) {
_set_obs_action_number(feat, i, option[0]);
_set_obs_action_number(feat, i, action.number_);
} else {
throw std::runtime_error("Unsupported message " + std::to_string(msg));
}
......@@ -2302,49 +2523,42 @@ private:
CardId spec_to_card_id(const std::string &spec, PlayerId player) {
int offset = 0;
// TODO: possible info leak
bool opponent = false;
if (spec[0] == 'o') {
player = 1 - player;
opponent = true;
offset++;
}
auto [loc, seq, pos] = spec_to_ls(spec.substr(offset));
return c_get_card_id(get_card_code(player, loc, seq));
}
CardId parse_card_id(const std::string &option, PlayerId player) {
CardId card_id = 0;
if (msg_ == MSG_SELECT_IDLECMD) {
if (!(option == "b" || option == "e")) {
auto n = option.size();
if (std::isalpha(option[n - 1])) {
card_id = spec_to_card_id(option.substr(2, n - 3), player);
} else {
card_id = spec_to_card_id(option.substr(2), player);
}
if (opponent) {
bool hidden_for_opponent = true;
if (
loc == LOCATION_MZONE || loc == LOCATION_SZONE ||
loc == LOCATION_GRAVE || loc == LOCATION_REMOVED) {
hidden_for_opponent = false;
}
} else if (msg_ == MSG_SELECT_CHAIN) {
if (option != "c") {
card_id = spec_to_card_id(option, player);
if (revealed_.size() != 0) {
hidden_for_opponent = false;
}
} else if (msg_ == MSG_SELECT_CARD || msg_ == MSG_SELECT_TRIBUTE ||
msg_ == MSG_SELECT_SUM || msg_ == MSG_SELECT_UNSELECT_CARD) {
if (option[0] != 'f') {
card_id = spec_to_card_id(option, player);
if (hidden_for_opponent) {
return 0;
}
} else if (msg_ == MSG_SELECT_EFFECTYN) {
card_id = spec_to_card_id(option.substr(2), player);
} else if (msg_ == MSG_SELECT_BATTLECMD) {
if (!(option == "m" || option == "e")) {
card_id = spec_to_card_id(option.substr(2), player);
Card c = get_card(player, loc, seq);
bool hide = c.position_ & POS_FACEDOWN;
if (revealed_.find(spec) != revealed_.end()) {
hide = false;
}
CardId card_id = 0;
if (!hide) {
card_id = c_get_card_id(c.code_);
}
}
return card_id;
return c_get_card_id(get_card_code(player, loc, seq));
}
void _set_obs_actions(TArray<uint8_t> &feat, const SpecIndex &spec2index,
int msg, const std::vector<std::string> &options) {
for (int i = 0; i < options.size(); ++i) {
_set_obs_action(feat, i, msg, options[i], spec2index, 0);
void _set_obs_actions(TArray<uint8_t> &feat, const std::vector<LegalAction> &actions) {
for (int i = 0; i < actions.size(); ++i) {
_set_obs_action(feat, i, actions[i]);
}
}
......@@ -2451,7 +2665,7 @@ private:
void WriteState(float reward, int win_reason = 0) {
State state = Allocate();
int n_options = options_.size();
int n_options = legal_actions_.size();
state["reward"_] = reward;
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
......@@ -2463,62 +2677,69 @@ private:
return;
}
auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// we can't shuffle because idx must be stable in callback
if (n_options > max_options()) {
options_.resize(max_options());
legal_actions_.resize(max_options());
}
// print spec2index
// for (auto const& [key, val] : spec2index) {
// fmt::println("{} {}", key, val);
// }
_set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
n_options = options_.size();
n_options = legal_actions_.size();
state["info:num_options"_] = n_options;
// update_h_card_ids from state
for (int i = 0; i < n_options; ++i) {
uint8_t spec_index1 = state["obs:actions_"_](i, 0);
uint8_t spec_index2 = state["obs:actions_"_](i, 1);
uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
if (spec_index == 0) {
h_card_ids_[i] = 0;
} else {
uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
h_card_ids_[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
auto &action = legal_actions_[i];
action.msg_ = msg_;
const auto &spec = action.spec_;
if (!spec.empty()) {
const auto& spec_info = find_spec_info(spec_infos, spec);
action.spec_index_ = spec_info.index;
if (action.cid_ == 0) {
action.cid_ = spec_info.cid;
}
}
}
_set_obs_actions(state["obs:actions_"_], legal_actions_);
// write history actions
int offset = n_history_actions_ - ha_p_;
int n_h_action_feats = history_actions_.Shape()[1];
auto ha_p = to_play_ == 0 ? ha_p_1_ : ha_p_2_;
auto &history_actions = to_play_ == 0 ? history_actions_1_ : history_actions_2_;
int offset = n_history_actions_ - ha_p;
int n_h_action_feats = history_actions.Shape()[1];
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions_[ha_p_].Data(), n_h_action_feats * offset);
(uint8_t *)history_actions[ha_p].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions_.Data(), n_h_action_feats * ha_p_);
(uint8_t *)history_actions.Data(), n_h_action_feats * ha_p);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 2)) == 0) {
if (uint8_t(state["obs:h_actions_"_](i, 3)) == 0) {
break;
}
state["obs:h_actions_"_](i, 13) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 13)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 14)));
state["obs:h_actions_"_](i, 14) = static_cast<uint8_t>(turn_diff);
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 12)));
state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(turn_diff);
}
}
void show_decision(int idx) {
fmt::println("Player {} chose \"{}\" in {}", to_play_, options_[idx],
options_);
std::string s;
const auto& a = legal_actions_[idx];
if (!a.spec_.empty()) {
s = a.spec_;
} else if (a.place_ != ActionPlace::None) {
s = action_place_to_string(a.place_);
} else if (a.position_ != 0) {
s = position_to_string(a.position_);
} else {
s = fmt::format("{}", a);
}
fmt::print("Player {} chose \"{}\" in {}\n", to_play_, s, legal_actions_);
}
std::tuple<std::vector<CardCode>, std::vector<CardCode>, std::string>
......@@ -2581,15 +2802,19 @@ private:
handle_multi_select();
} else {
handle_message();
if (options_.empty()) {
if (legal_actions_.empty()) {
continue;
}
}
if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) {
if (options_.size() == 1) {
if (legal_actions_.size() == 1) {
callback_(0);
update_h_card_ids(to_play_, 0);
update_history_actions(to_play_, 0);
auto la = legal_actions_[0];
la.msg_ = msg_;
if (la.cid_ == 0 && !la.spec_.empty()) {
la.cid_ = spec_to_card_id(la.spec_, to_play_);
}
update_history_actions(to_play_, la);
if (verbose_) {
show_decision(0);
}
......@@ -2597,7 +2822,7 @@ private:
return;
}
} else {
auto idx = players_[to_play_]->think(options_);
auto idx = players_[to_play_]->think(legal_actions_);
callback_(idx);
if (verbose_) {
show_decision(idx);
......@@ -2606,7 +2831,7 @@ private:
}
}
done_ = true;
options_.clear();
legal_actions_.clear();
}
uint8_t read_u8() { return data_[dp_++]; }
......@@ -2653,7 +2878,12 @@ private:
int32_t bl = YGO_QueryCard(pduel_, player, loc, seq, flags, query_buf_);
qdp_ = 0;
if (bl <= 0) {
throw std::runtime_error("[get_card] Invalid card (bl <= 0)");
show_deck(0);
show_deck(1);
show_turn();
show_buffer();
auto s = fmt::format("[get_card] Invalid card (bl <= 0), player: {}, loc: {}, seq: {}", player, loc, seq);
throw std::runtime_error(s);
}
uint32_t f = q_read_u32();
if (f == LEN_EMPTY) {
......@@ -2728,7 +2958,7 @@ private:
c.attack_ = q_read_u32();
c.defense_ = q_read_u32();
// TODO: equip_target
// TODO(2): equip_target
if (f & QUERY_EQUIP_CARD) {
q_read_u32();
}
......@@ -2744,7 +2974,7 @@ private:
cards.push_back(c_);
}
// TODO: counters
// TODO(2): counters
uint32_t n_counters = q_read_u32();
for (int i = 0; i < n_counters; ++i) {
if (i == 0) {
......@@ -2803,7 +3033,7 @@ private:
auto controller = read_u8();
auto loc = read_u8();
auto seq = read_u8();
uint32_t data = -1;
uint32_t data = 0;
if (extra) {
if (extra8) {
data = read_u8();
......@@ -2816,6 +3046,23 @@ private:
return card_specs;
}
std::tuple<CardCode, int> unpack_desc(CardCode code, uint32_t desc) {
if (desc < DESCRIPTION_LIMIT) {
return {0, desc};
}
CardCode code_ = desc >> 4;
int idx = desc & 0xf;
if (idx < 0 || idx >= 14) {
fmt::print("Code: {}, Code_: {}, Desc: {}\n", code, code_, desc);
show_deck(0);
show_deck(1);
show_buffer();
show_turn();
throw std::runtime_error("Invalid effect index: " + std::to_string(idx));
}
return {code_, idx + CARD_EFFECT_OFFSET};
}
std::string cardlist_info_for_player(const Card &card, PlayerId pl) {
std::string spec = card.get_spec(pl);
if (card.location_ == LOCATION_DECK) {
......@@ -2833,7 +3080,7 @@ private:
// 3. update to_play_ and options_ if need action
void handle_message() {
msg_ = int(data_[dp_++]);
options_ = {};
legal_actions_ = {};
if (verbose_) {
fmt::println("Message {}, length {}, dp {}", msg_to_string(msg_), dl_, dp_);
......@@ -3097,11 +3344,11 @@ private:
uint8_t pos = read_u8();
uint8_t type = read_u8();
uint32_t value = read_u32();
Card card = get_card(player, loc, seq);
if (card.code_ == 0) {
return;
}
if (type == CHINT_RACE) {
Card card = get_card(player, loc, seq);
if (card.code_ == 0) {
return;
}
std::string races_str = "TODO";
for (PlayerId pl = 0; pl < 2; pl++) {
players_[pl]->notify(fmt::format("{} ({}) selected {}.",
......@@ -3109,6 +3356,10 @@ private:
races_str));
}
} else if (type == CHINT_ATTRIBUTE) {
Card card = get_card(player, loc, seq);
if (card.code_ == 0) {
return;
}
std::string attributes_str = "TODO";
for (PlayerId pl = 0; pl < 2; pl++) {
players_[pl]->notify(fmt::format("{} ({}) selected {}.",
......@@ -3229,7 +3480,7 @@ private:
return;
}
dp_ += 6;
// TODO: implement output
// TODO(3): implement output
} else if (msg_ == MSG_CARD_TARGET) {
if (!verbose_) {
dp_ = dl_;
......@@ -3301,7 +3552,7 @@ private:
players_[pl]->notify(str);
}
} else if (msg_ == MSG_SORT_CARD) {
// TODO: implement action
// TODO(3): implement action
if (!verbose_) {
dp_ = dl_;
resp_buf_[0] = 255;
......@@ -3374,7 +3625,7 @@ private:
auto pl = players_[player];
PlayerId op_id = 1 - player;
auto op = players_[op_id];
// TODO: counter type to string
// TODO(3): counter type to string
pl->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(player)));
op->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(op_id)));
} else if (msg_ == MSG_REMOVE_COUNTER) {
......@@ -3406,7 +3657,7 @@ private:
dp_ = dl_;
return;
}
// TODO: implement output
// TODO(3): implement output
dp_ = dl_;
} else if (msg_ == MSG_SHUFFLE_DECK) {
if (!verbose_) {
......@@ -3699,52 +3950,64 @@ private:
if (verbose_) {
pl->notify("Battle menu:");
}
for (const auto [code, spec, data] : activatable) {
// TODO: Add effect description to indicate which effect is being activated
options_.push_back("v " + spec);
for (const auto [code_t, spec, desc] : activatable) {
CardCode code = code_t;
if(code & 0x80000000) {
code &= 0x7fffffff;
}
auto [code_d, eff_idx] = unpack_desc(code, desc);
if (desc == 0) {
code_d = code;
}
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
}
legal_actions_.push_back(la);
if (verbose_) {
auto [loc, seq, pos] = spec_to_ls(spec);
auto c = get_card(player, loc, seq);
pl->notify("v " + spec + ": activate " + c.name_ + " (" +
std::to_string(c.attack_) + "/" +
std::to_string(c.defense_) + ")");
auto c = c_get_card(code);
int cmd_idx = legal_actions_.size();
std::string s = fmt::format(
"{}: activate {}({}) [{}/{}] ({})",
cmd_idx, c.name_, spec, c.attack_, c.defense_, c.get_effect_description(code_d, eff_idx));
}
}
for (const auto [code, spec, data] : attackable) {
// TODO: add this as feature
bool direct_attackable = data & 0x1;
options_.push_back("a " + spec);
auto act = direct_attackable ? ActionAct::DirectAttack : ActionAct::Attack;
legal_actions_.push_back(
LegalAction::act_spec(act, spec));
if (verbose_) {
auto [loc, seq, pos] = spec_to_ls(spec);
auto c = get_card(player, loc, seq);
std::string s;
auto [controller, loc, seq, pos] = spec_to_ls(player, spec);
auto c = get_card(controller, loc, seq);
int cmd_idx = legal_actions_.size();
auto attack_str = direct_attackable ? "direct attack" : "attack";
std::string s = fmt::format(
"{}: {} {}({}) ", cmd_idx, attack_str, c.name_, spec);
if (c.type_ & TYPE_LINK) {
s = "a " + spec + ": " + c.name_ + " (" +
std::to_string(c.attack_) + ")";
s += fmt::format("[{}]", c.attack_);
} else {
s = "a " + spec + ": " + c.name_ + " (" +
std::to_string(c.attack_) + "/" +
std::to_string(c.defense_) + ")";
}
if (direct_attackable) {
s += " direct attack";
} else {
s += " attack";
s += fmt::format("[{}/{}]", c.attack_, c.defense_);
}
pl->notify(s);
}
}
if (to_m2) {
options_.push_back("m");
legal_actions_.push_back(
LegalAction::phase(ActionPhase::Main2));
int cmd_idx = legal_actions_.size();
if (verbose_) {
pl->notify("m: Main phase 2.");
pl->notify(fmt::format("{}: Main phase 2.", cmd_idx));
}
}
if (to_ep) {
if (!to_m2) {
options_.push_back("e");
legal_actions_.push_back(
LegalAction::phase(ActionPhase::End));
int cmd_idx = legal_actions_.size();
if (verbose_) {
pl->notify("e: End phase.");
pl->notify(fmt::format("{}: End phase.", cmd_idx));
}
}
}
......@@ -3752,14 +4015,15 @@ private:
int n_attackables = attackable.size();
to_play_ = player;
callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) {
const auto &la = legal_actions_[idx];
if (idx < n_activatables) {
YGO_SetResponsei(pduel_, idx << 16);
} else if (idx < (n_activatables + n_attackables)) {
idx = idx - n_activatables;
YGO_SetResponsei(pduel_, (idx << 16) + 1);
} else if ((options_[idx] == "e") && to_ep) {
} else if ((la.phase_ == ActionPhase::End) && to_ep) {
YGO_SetResponsei(pduel_, 3);
} else if ((options_[idx] == "m") && to_m2) {
} else if ((la.phase_ == ActionPhase::Main2) && to_m2) {
YGO_SetResponsei(pduel_, 2);
} else {
throw std::runtime_error("Invalid option");
......@@ -3777,21 +4041,18 @@ private:
std::vector<std::string> select_specs;
select_specs.reserve(select_size);
if (verbose_) {
std::vector<Card> cards;
auto pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards:");
for (int i = 0; i < select_size; ++i) {
auto code = read_u32();
auto loc = read_u32();
Card card = c_get_card(code);
card.set_location(loc);
cards.push_back(card);
}
auto pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards:");
for (const auto &card : cards) {
auto spec = card.get_spec(player);
select_specs.push_back(spec);
pl->notify(spec + ": " + card.name_);
auto s = fmt::format("{}: {}({})", i + 1, card.name_, spec);
pl->notify(s);
}
} else {
for (int i = 0; i < select_size; ++i) {
......@@ -3807,22 +4068,22 @@ private:
auto unselect_size = read_u8();
// unselect not allowed (no regrets!)
// unselect not allowed (no regrets)
dp_ += 8 * unselect_size;
for (int j = 0; j < select_specs.size(); ++j) {
options_.push_back(select_specs[j]);
legal_actions_.push_back(LegalAction::from_spec(select_specs[j]));
}
if (finishable) {
options_.push_back("f");
legal_actions_.push_back(LegalAction::finish());
}
// cancelable and finishable not needed
to_play_ = player;
callback_ = [this](int idx) {
if (options_[idx] == "f") {
if (legal_actions_[idx].finish_) {
YGO_SetResponsei(pduel_, -1);
} else {
resp_buf_[0] = 1;
......@@ -3893,7 +4154,7 @@ private:
}
}
// TODO: use this when added to history actions
// TODO(1): use this when added to history actions
// if ((min == max) && (max == specs.size())) {
// resp_buf_[0] = specs.size();
// for (int i = 0; i < specs.size(); ++i) {
......@@ -3974,7 +4235,7 @@ private:
// combs = combinations_with_weight(release_params, min);
}
// TODO: use this when added to history actions
// TODO(1): use this when added to history actions
// if (max == specs.size()) {
// // tribute all
// resp_buf_[0] = specs.size();
......@@ -4126,25 +4387,18 @@ private:
// auto hint_timing = read_u32();
// auto other_timing = read_u32();
std::vector<Card> cards;
std::vector<CardCode> codes;
std::vector<uint32_t> descs;
std::vector<uint32_t> spec_codes;
std::vector<std::string> specs;
for (int i = 0; i < size; ++i) {
auto et = read_u8();
auto flag = read_u8();
CardCode code = read_u32();
if (verbose_) {
uint32_t loc = read_u32();
Card card = c_get_card(code);
card.set_location(loc);
cards.push_back(card);
spec_codes.push_back(card.get_spec_code(player));
} else {
PlayerId c = read_u8();
uint8_t loc = read_u8();
uint8_t seq = read_u8();
uint8_t pos = read_u8();
spec_codes.push_back(ls_to_spec_code(loc, seq, pos, c != player));
}
codes.push_back(code);
PlayerId c = read_u8();
uint8_t loc = read_u8();
uint8_t seq = read_u8();
uint8_t pos = read_u8();
specs.push_back(ls_to_spec(loc, seq, pos, c != player));
uint32_t desc = read_u32();
descs.push_back(desc);
}
......@@ -4168,58 +4422,42 @@ private:
op->seen_waiting_ = true;
}
std::vector<int> chain_index;
ankerl::unordered_dense::map<uint32_t, int> chain_counts;
ankerl::unordered_dense::map<uint32_t, int> chain_orders;
std::vector<std::string> chain_specs;
std::vector<std::string> effect_descs;
for (int i = 0; i < size; i++) {
chain_index.push_back(i);
chain_counts[spec_codes[i]] += 1;
if (verbose_) {
pl->notify("Select chain:");
}
for (int i = 0; i < size; i++) {
auto spec_code = spec_codes[i];
auto cs = code_to_spec(spec_code);
auto chain_count = chain_counts[spec_code];
if (chain_count > 1) {
// TODO: should use desc to indicate activate which effect
cs.push_back('a' + chain_orders[spec_code]);
}
chain_orders[spec_code]++;
chain_specs.push_back(cs);
if (verbose_) {
const auto &card = cards[i];
effect_descs.push_back(card.get_effect_description(descs[i], true));
CardCode code = codes[i];
uint32_t desc = descs[i];
auto spec = specs[i];
auto [code_d, eff_idx] = unpack_desc(code, desc);
if (desc == 0) {
code_d = code;
}
}
if (verbose_) {
if (forced) {
pl->notify("Select chain:");
} else {
pl->notify("Select chain (c to cancel):");
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
}
for (int i = 0; i < size; i++) {
const auto &effect_desc = effect_descs[i];
if (effect_desc.empty()) {
pl->notify(chain_specs[i] + ": " + cards[i].name_);
} else {
pl->notify(chain_specs[i] + " (" + cards[i].name_ +
"): " + effect_desc);
}
legal_actions_.push_back(la);
if (verbose_) {
auto c = c_get_card(code);
std::string s = fmt::format(
"{}: {}({}) ({})",
i + 1, c.name_, spec, c.get_effect_description(code_d, eff_idx));
pl->notify(s);
}
}
for (const auto &spec : chain_specs) {
options_.push_back(spec);
}
if (!forced) {
options_.push_back("c");
legal_actions_.push_back(LegalAction::cancel());
if (verbose_) {
pl->notify(fmt::format("{}: cancel", size + 1));
}
}
to_play_ = player;
callback_ = [this, forced](int idx) {
const auto &option = options_[idx];
if (option == "c") {
const auto &action = legal_actions_[idx];
if (action.act_ == ActionAct::Cancel) {
if (forced) {
fmt::print("cancel not allowed in forced chain\n");
YGO_SetResponsei(pduel_, 0);
......@@ -4232,58 +4470,76 @@ private:
};
} else if (msg_ == MSG_SELECT_YESNO) {
auto player = read_u8();
auto desc = read_u32();
auto [code, eff_idx] = unpack_desc(0, desc);
if (desc == 0) {
show_buffer();
auto s = fmt::format("Unknown desc {} in select_yesno", desc);
throw std::runtime_error(s);
}
auto la = LegalAction::activate_spec(eff_idx, "");
if (code != 0) {
la.cid_ = c_get_card_id(code);
}
legal_actions_.push_back(la);
if (verbose_) {
auto desc = read_u32();
auto pl = players_[player];
std::string opt;
if (desc > 10000) {
auto code = desc >> 4;
auto card = c_get_card(code);
auto opt_idx = desc & 0xf;
if (opt_idx < card.strings_.size()) {
opt = card.strings_[opt_idx];
std::string s;
if (code == 0) {
s = get_system_string(eff_idx);
} else {
Card c = c_get_card(code);
int cmd_idx = legal_actions_.size();
eff_idx -= CARD_EFFECT_OFFSET;
if (eff_idx >= c.strings_.size()) {
throw std::runtime_error(
fmt::format("Unknown effect {} of {}", eff_idx, c.name_));
}
if (opt.empty()) {
opt = "Unknown question from " + card.name_ + ". Yes or no?";
auto str = c.strings_[eff_idx];
if (str.empty()) {
str = "effect " + std::to_string(eff_idx);
}
} else {
opt = get_system_string(desc);
s = fmt::format("{} ({})", c.name_, str);
}
pl->notify(opt);
pl->notify("Please enter y or n.");
} else {
dp_ += 4;
pl->notify("1: " + s);
pl->notify("2: No");
}
options_ = {"y", "n"};
// TODO: maybe add card id to cancel
legal_actions_.push_back(LegalAction::cancel());
to_play_ = player;
callback_ = [this](int idx) {
if (idx == 0) {
YGO_SetResponsei(pduel_, 1);
} else if (idx == 1) {
YGO_SetResponsei(pduel_, 0);
} else {
throw std::runtime_error("Invalid option");
}
};
} else if (msg_ == MSG_SELECT_EFFECTYN) {
auto player = read_u8();
std::string spec;
CardCode code = read_u32();
auto ct = read_u8();
auto loc = read_u8();
auto seq = read_u8();
auto pos = read_u8();
auto desc = read_u32();
std::string spec = ls_to_spec(loc, seq, pos, ct != player);
auto [code_d, eff_idx] = unpack_desc(code, desc);
if (desc == 0) {
code_d = code;
}
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
}
legal_actions_.push_back(la);
if (verbose_) {
CardCode code = read_u32();
uint32_t loc = read_u32();
Card card = c_get_card(code);
card.set_location(loc);
auto desc = read_u32();
Card c = c_get_card(code);
auto pl = players_[player];
spec = card.get_spec(player);
auto name = card.name_;
auto name = c.name_;
std::string s;
if (desc == 0) {
// From [%ls], activate [%ls]?
s = "From " + card.get_spec(player) + ", activate " + name + "?";
} else if (desc < 2048) {
if (code_d == 0) {
s = get_system_string(desc);
std::string fmt_str = "[%ls]";
auto pos = find_substrs(s, fmt_str);
......@@ -4295,87 +4551,74 @@ private:
} else if (pos.size() == 2) {
auto p1 = pos[0];
auto p2 = pos[1];
s = s.substr(0, p1) + card.get_spec(player) +
s = s.substr(0, p1) + spec +
s.substr(p1 + fmt_str.size(), p2 - p1 - fmt_str.size()) + name +
s.substr(p2 + fmt_str.size());
} else {
throw std::runtime_error("Unknown effectyn desc " +
std::to_string(desc) + " of " + name);
}
} else if (desc < 10000u) {
s = get_system_string(desc);
} else {
CardCode code = (desc >> 4) & 0x0fffffff;
uint32_t offset = desc & 0xf;
if (cards_.find(code) != cards_.end()) {
auto &card_ = c_get_card(code);
s = card_.strings_[offset];
if (s.empty()) {
s = "???";
}
} else {
throw std::runtime_error("Unknown effectyn desc " +
std::to_string(desc) + " of " + name);
}
s = fmt::format(
"{}({}) ({})", c.name_, spec, c.get_effect_description(code_d, eff_idx));
}
pl->notify(s);
pl->notify("Please enter y or n.");
} else {
dp_ += 4;
auto c = read_u8();
auto loc = read_u8();
auto seq = read_u8();
auto pos = read_u8();
dp_ += 4;
spec = ls_to_spec(loc, seq, pos, c != player);
pl->notify("1: " + s);
pl->notify("2: No");
}
options_ = {"y " + spec, "n " + spec};
// TODO: maybe add card info to cancel
legal_actions_.push_back(LegalAction::cancel());
to_play_ = player;
callback_ = [this](int idx) {
if (idx == 0) {
YGO_SetResponsei(pduel_, 1);
} else if (idx == 1) {
YGO_SetResponsei(pduel_, 0);
} else {
throw std::runtime_error("Invalid option");
}
};
} else if (msg_ == MSG_SELECT_OPTION) {
// TODO: add card information
auto player = read_u8();
auto size = read_u8();
if (verbose_) {
auto pl = players_[player];
pl->notify("Select an option:");
for (int i = 0; i < size; ++i) {
auto opt = read_u32();
players_[player]->notify("Select an option:");
}
for (int i = 0; i < size; ++i) {
auto desc = read_u32();
auto [code, eff_idx] = unpack_desc(0, desc);
if (desc == 0) {
show_buffer();
auto s = fmt::format("Unknown desc {} in select_option", desc);
throw std::runtime_error(s);
}
auto la = LegalAction::activate_spec(eff_idx, "");
if (code != 0) {
la.cid_ = c_get_card_id(code);
}
legal_actions_.push_back(la);
if (verbose_) {
std::string s;
if (opt > 10000) {
CardCode code = opt >> 4;
s = c_get_card(code).strings_[opt & 0xf];
if (code == 0) {
s = get_system_string(eff_idx);
} else {
s = get_system_string(opt);
Card c = c_get_card(code);
int cmd_idx = legal_actions_.size();
eff_idx -= CARD_EFFECT_OFFSET;
if (eff_idx >= c.strings_.size()) {
throw std::runtime_error(
fmt::format("Unknown effect {} of {}", eff_idx, c.name_));
}
auto str = c.strings_[eff_idx];
if (str.empty()) {
str = "effect " + std::to_string(eff_idx);
}
s = fmt::format("{} ({})", c.name_, str);
}
std::string option = std::to_string(i + 1);
options_.push_back(option);
pl->notify(option + ": " + s);
}
} else {
for (int i = 0; i < size; ++i) {
dp_ += 4;
options_.push_back(std::to_string(i + 1));
players_[player]->notify(std::to_string(i + 1) + ": " + s);
}
}
to_play_ = player;
callback_ = [this](int idx) {
if (verbose_) {
players_[to_play_]->notify("You selected option " + options_[idx] +
".");
players_[1 - to_play_]->notify(players_[to_play_]->nickname_ +
" selected option " + options_[idx] +
".");
}
YGO_SetResponsei(pduel_, idx);
};
} else if (msg_ == MSG_SELECT_IDLECMD) {
......@@ -4397,90 +4640,97 @@ private:
pl->notify("Select a card and action to perform.");
}
for (const auto &[code, spec, data] : summonable_) {
std::string option = "s " + spec;
options_.push_back(option);
legal_actions_.push_back(LegalAction::act_spec(ActionAct::Summon, spec));
if (verbose_) {
const auto &name = c_get_card(code).name_;
pl->notify(option + ": Summon " + name +
" in face-up attack position.");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Summon {} in face-up attack position", cmd_idx, name));
}
}
offset += summonable_.size();
int spsummon_offset = offset;
for (const auto &[code, spec, data] : spsummon_) {
std::string option = "c " + spec;
options_.push_back(option);
legal_actions_.push_back(LegalAction::act_spec(ActionAct::SpSummon, spec));
if (verbose_) {
const auto &name = c_get_card(code).name_;
pl->notify(option + ": Special summon " + name + ".");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Special summon {}", cmd_idx, name));
}
}
offset += spsummon_.size();
int repos_offset = offset;
for (const auto &[code, spec, data] : repos_) {
std::string option = "r " + spec;
options_.push_back(option);
legal_actions_.push_back(LegalAction::act_spec(ActionAct::Repo, spec));
if (verbose_) {
const auto &name = c_get_card(code).name_;
pl->notify(option + ": Reposition " + name + ".");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Change position of {}", cmd_idx, name));
}
}
offset += repos_.size();
int mset_offset = offset;
for (const auto &[code, spec, data] : idle_mset_) {
std::string option = "m " + spec;
options_.push_back(option);
legal_actions_.push_back(LegalAction::act_spec(ActionAct::MSet, spec));
if (verbose_) {
const auto &name = c_get_card(code).name_;
pl->notify(option + ": Summon " + name +
" in face-down defense position.");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Summon {} in face-down defense position", cmd_idx, name));
}
}
offset += idle_mset_.size();
int set_offset = offset;
for (const auto &[code, spec, data] : idle_set_) {
std::string option = "t " + spec;
options_.push_back(option);
legal_actions_.push_back(LegalAction::act_spec(ActionAct::Set, spec));
if (verbose_) {
const auto &name = c_get_card(code).name_;
pl->notify(option + ": Set " + name + ".");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format(
"{}: Set {}", cmd_idx, name));
}
}
offset += idle_set_.size();
int activate_offset = offset;
ankerl::unordered_dense::map<std::string, int> idle_activate_count;
for (const auto &[code, spec, data] : idle_activate_) {
idle_activate_count[spec] += 1;
}
ankerl::unordered_dense::map<std::string, int> activate_count;
for (const auto &[code, spec, data] : idle_activate_) {
// TODO: use effect description to indicate which effect to activate
std::string option = "v " + spec;
int count = idle_activate_count[spec];
activate_count[spec]++;
if (count > 1) {
option.push_back('a' + activate_count[spec] - 1);
for (const auto &[code_t, spec, desc] : idle_activate_) {
CardCode code = code_t;
if(code & 0x80000000) {
code &= 0x7fffffff;
}
options_.push_back(option);
auto [code_d, eff_idx] = unpack_desc(code, desc);
if (desc == 0) {
code_d = code;
}
auto la = LegalAction::activate_spec(eff_idx, spec);
if (code_d != 0) {
la.cid_ = c_get_card_id(code_d);
}
legal_actions_.push_back(la);
if (verbose_) {
pl->notify(option + ": " +
c_get_card(code).get_effect_description(data));
auto c = c_get_card(code);
int cmd_idx = legal_actions_.size();
std::string s = fmt::format(
"{}: Activate {}({}) ({})",
cmd_idx, c.name_, spec, c.get_effect_description(code_d, eff_idx));
pl->notify(s);
}
}
if (to_bp_) {
std::string cmd = "b";
options_.push_back(cmd);
legal_actions_.push_back(LegalAction::phase(ActionPhase::Battle));
if (verbose_) {
pl->notify(cmd + ": Enter the battle phase.");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format("{}: Enter the battle phase.", cmd_idx));
}
}
if (to_ep_) {
if (!to_bp_) {
std::string cmd = "e";
options_.push_back(cmd);
legal_actions_.push_back(LegalAction::phase(ActionPhase::End));
if (verbose_) {
pl->notify(cmd + ": End phase.");
int cmd_idx = legal_actions_.size();
pl->notify(fmt::format("{}: End phase.", cmd_idx));
}
}
}
......@@ -4488,104 +4738,90 @@ private:
to_play_ = player;
callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset,
activate_offset](int idx) {
const auto &option = options_[idx];
char cmd = option[0];
if (cmd == 'b') {
const auto &action = legal_actions_[idx];
if (action.phase_ == ActionPhase::Battle) {
YGO_SetResponsei(pduel_, 6);
} else if (cmd == 'e') {
} else if (action.phase_ == ActionPhase::End) {
YGO_SetResponsei(pduel_, 7);
} else {
auto spec = option.substr(2);
if (cmd == 's') {
auto act = action.act_;
if (act == ActionAct::Summon) {
uint32_t idx_ = idx;
YGO_SetResponsei(pduel_, idx_ << 16);
} else if (cmd == 'c') {
} else if (act == ActionAct::SpSummon) {
uint32_t idx_ = idx - spsummon_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 1);
} else if (cmd == 'r') {
} else if (act == ActionAct::Repo) {
uint32_t idx_ = idx - repos_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 2);
} else if (cmd == 'm') {
} else if (act == ActionAct::MSet) {
uint32_t idx_ = idx - mset_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 3);
} else if (cmd == 't') {
} else if (act == ActionAct::Set) {
uint32_t idx_ = idx - set_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 4);
} else if (cmd == 'v') {
} else if (act == ActionAct::Activate) {
uint32_t idx_ = idx - activate_offset;
YGO_SetResponsei(pduel_, (idx_ << 16) + 5);
} else {
throw std::runtime_error("Invalid option: " + option);
}
}
};
} else if (msg_ == MSG_SELECT_PLACE) {
} else if (msg_ == MSG_SELECT_PLACE || msg_ == MSG_SELECT_DISFIELD) {
// TODO(1): add card informaton to select place
auto player = read_u8();
auto count = read_u8();
if (count == 0) {
count = 1;
}
auto flag = read_u32();
options_ = flag_to_usable_cardspecs(flag);
if (verbose_) {
std::string specs_str = options_[0];
for (int i = 1; i < options_.size(); ++i) {
specs_str += ", " + options_[i];
}
if (count == 1) {
players_[player]->notify("Select place for card, one of " +
specs_str + ".");
} else {
players_[player]->notify("Select " + std::to_string(count) +
" places for card, from " + specs_str + ".");
}
}
to_play_ = player;
callback_ = [this, player](int idx) {
int y = player + 1;
std::string spec = options_[idx];
auto plr = player;
if (spec[0] == 'o') {
plr = 1 - player;
spec = spec.substr(1);
}
auto [loc, seq, pos] = spec_to_ls(spec);
resp_buf_[0] = plr;
resp_buf_[1] = loc;
resp_buf_[2] = seq;
YGO_SetResponseb(pduel_, resp_buf_);
};
} else if (msg_ == MSG_SELECT_DISFIELD) {
auto player = read_u8();
auto count = read_u8();
if (count == 0) {
count = 1;
if (count != 1) {
auto s = fmt::format("Select place count {} not implemented for {}",
count, msg_ == MSG_SELECT_PLACE ? "place" : "disfield");
throw std::runtime_error(s);
}
auto flag = read_u32();
options_ = flag_to_usable_cardspecs(flag);
auto places = flag_to_usable_places(flag);
if (verbose_) {
std::string specs_str = options_[0];
for (int i = 1; i < options_.size(); ++i) {
specs_str += ", " + options_[i];
}
if (count == 1) {
players_[player]->notify("Select place for card, one of " +
specs_str + ".");
} else {
throw std::runtime_error("Select disfield count " +
std::to_string(count) + " not implemented");
auto place_s = msg_ == MSG_SELECT_PLACE ? "place" : "disfield";
auto s = fmt::format("Select {} for card, one of:", place_s);
players_[player]->notify(s);
}
for (int i = 0; i < places.size(); ++i) {
legal_actions_.push_back(LegalAction::place(places[i]));
if (verbose_) {
auto s = fmt::format("{}: {}", i + 1, action_place_to_string(places[i]));
players_[player]->notify(s);
}
}
to_play_ = player;
callback_ = [this, player](int idx) {
int y = player + 1;
std::string spec = options_[idx];
auto plr = player;
if (spec[0] == 'o') {
auto place = legal_actions_[idx].place_;
int i = static_cast<int>(place);
uint8_t plr = player;
uint8_t loc;
uint8_t seq;
if (
i >= static_cast<int>(ActionPlace::MZone1) &&
i <= static_cast<int>(ActionPlace::MZone7)) {
loc = LOCATION_MZONE;
seq = i - static_cast<int>(ActionPlace::MZone1);
} else if (
i >= static_cast<int>(ActionPlace::SZone1) &&
i <= static_cast<int>(ActionPlace::SZone8)) {
loc = LOCATION_SZONE;
seq = i - static_cast<int>(ActionPlace::SZone1);
} else if (
i >= static_cast<int>(ActionPlace::OpMZone1) &&
i <= static_cast<int>(ActionPlace::OpMZone7)) {
plr = 1 - player;
spec = spec.substr(1);
loc = LOCATION_MZONE;
seq = i - static_cast<int>(ActionPlace::OpMZone1);
} else if (
i >= static_cast<int>(ActionPlace::OpSZone1) &&
i <= static_cast<int>(ActionPlace::OpSZone8)) {
plr = 1 - player;
loc = LOCATION_SZONE;
seq = i - static_cast<int>(ActionPlace::OpSZone1);
}
auto [loc, seq, pos] = spec_to_ls(spec);
resp_buf_[0] = plr;
resp_buf_[1] = loc;
resp_buf_[2] = seq;
......@@ -4620,7 +4856,7 @@ private:
// auto spec = ls_to_spec(loc, seq, 0, controller != player);
// options_.push_back(spec);
}
// TODO: implement action
// TODO(2): implement action
n_counters_ = count;
uint16_t resp1 = static_cast<uint16_t>(std::min(counter_count, counters[0]));
memcpy(resp_buf_, &resp1, 2);
......@@ -4644,19 +4880,15 @@ private:
" not implemented for announce number");
}
numbers.push_back(number);
options_.push_back(std::string(1, '0' + number));
legal_actions_.push_back(LegalAction::number(number));
}
if (verbose_) {
auto pl = players_[player];
std::string str = "Select a number, one of: [";
std::string str = "Select a number, one of:";
pl->notify(str);
for (int i = 0; i < count; ++i) {
str += std::to_string(numbers[i]);
if (i < count - 1) {
str += ", ";
}
pl->notify(fmt::format("{}: {}", i + 1, numbers[i]));
}
str += "]";
pl->notify(str);
}
to_play_ = player;
callback_ = [this](int idx) {
......@@ -4675,7 +4907,7 @@ private:
attrs.push_back(i + 1);
}
}
// TODO(2): implement action
if (count != 1) {
throw std::runtime_error("Announce attrib count " +
std::to_string(count) + " not implemented");
......@@ -4686,40 +4918,28 @@ private:
pl->notify("Select " + std::to_string(count) +
" attributes separated by spaces:");
for (int i = 0; i < attrs.size(); i++) {
pl->notify(std::to_string(attrs[i]) + ": " +
attribute_to_string(1 << (attrs[i] - 1)));
pl->notify(fmt::format("{}: {}", i + 1, attribute_to_string(1 << (attrs[i] - 1))));
}
}
auto combs = combinations(attrs.size(), count);
for (const auto &comb : combs) {
std::string option = "";
for (int j = 0; j < count; ++j) {
option += std::to_string(attrs[comb[j]]);
if (j < count - 1) {
option += " ";
}
}
options_.push_back(option);
// auto combs = combinations(attrs.size(), count);
for (int i = 0; i < attrs.size(); i++) {
legal_actions_.push_back(LegalAction::attribute(1 << (attrs[i] - 1)));
}
to_play_ = player;
callback_ = [this](int idx) {
const auto &option = options_[idx];
const auto &action = legal_actions_[idx];
uint32_t resp = 0;
int i = 0;
while (i < option.size()) {
resp |= 1 << (option[i] - '1');
i += 2;
}
resp |= action.attribute_;
YGO_SetResponsei(pduel_, resp);
};
} else if (msg_ == MSG_SELECT_POSITION) {
// TODO: add card as feature
auto player = read_u8();
auto code = read_u32();
auto valid_pos = read_u8();
CardId cid = c_get_card_id(code);
if (verbose_) {
auto pl = players_[player];
......@@ -4727,25 +4947,25 @@ private:
pl->notify("Select position for " + card.name_ + ":");
}
std::vector<uint8_t> positions;
int i = 1;
for (auto pos : {POS_FACEUP_ATTACK, POS_FACEDOWN_ATTACK,
POS_FACEUP_DEFENSE, POS_FACEDOWN_DEFENSE}) {
if (valid_pos & pos) {
positions.push_back(pos);
options_.push_back(std::to_string(i));
LegalAction la;
la.cid_ = cid;
la.position_ = pos;
legal_actions_.push_back(la);
int cmd_idx = legal_actions_.size();
if (verbose_) {
auto pl = players_[player];
pl->notify(fmt::format("{}: {}", i, position_to_string(pos)));
pl->notify(fmt::format("{}: {}", cmd_idx, position_to_string(pos)));
}
}
i++;
}
to_play_ = player;
callback_ = [this](int idx) {
uint8_t pos = options_[idx][0] - '1';
YGO_SetResponsei(pduel_, 1 << pos);
uint8_t pos = legal_actions_[idx].position_;
YGO_SetResponsei(pduel_, pos);
};
} else {
show_deck(0);
......@@ -4794,4 +5014,52 @@ using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
} // namespace ygopro
template <>
struct fmt::formatter<ygopro::LegalAction>: formatter<string_view> {
// Format the LegalAction object
template <typename FormatContext>
auto format(const ygopro::LegalAction& action, FormatContext& ctx) const {
std::stringstream ss;
ss << "{";
if (!action.spec_.empty()) {
ss << "spec='" << action.spec_ << "', ";
}
if (action.cid_ != 0) {
ss << "cid=" << action.cid_ << ", ";
}
if (action.act_ != ygopro::ActionAct::None) {
ss << "act=" << ygopro::action_act_to_string(action.act_) << ", ";
}
if (action.phase_ != ygopro::ActionPhase::None) {
ss << "phase=" << ygopro::action_phase_to_string(action.phase_) << ", ";
}
if (action.finish_) {
ss << "finish=true, ";
}
if (action.position_ != 0) {
ss << "position=" << ygopro::position_to_string(action.position_) << ", ";
}
if (action.effect_ != -1) {
ss << "effect=" << action.effect_ << ", ";
}
if (action.number_ != 0) {
ss << "number=" << int(action.number_) << ", ";
}
if (action.place_ != ygopro::ActionPlace::None) {
ss << "place=" << ygopro::action_place_to_string(action.place_) << ", ";
}
if (action.attribute_ != 0) {
ss << "attribute=" << ygopro::attribute_to_string(action.attribute_) << ", ";
}
std::string s = ss.str();
if (s.back() == ' ') {
s.pop_back();
s.pop_back();
}
s.push_back('}');
return format_to(ctx.out(), "{}", s);
}
};
#endif // YGOENV_YGOPRO_YGOPRO_H_
from ygoenv.python.api import py_env
from .ygopro0_ygoenv import (
_YGOPro0EnvPool,
_YGOPro0EnvSpec,
init_module,
)
(
YGOPro0EnvSpec,
YGOPro0DMEnvPool,
YGOPro0GymEnvPool,
YGOPro0GymnasiumEnvPool,
) = py_env(_YGOPro0EnvSpec, _YGOPro0EnvPool)
__all__ = [
"YGOPro0EnvSpec",
"YGOPro0DMEnvPool",
"YGOPro0GymEnvPool",
"YGOPro0GymnasiumEnvPool",
]
from ygoenv.registration import register
register(
task_id="YGOPro-v0",
import_path="ygoenv.ygopro0",
spec_cls="YGOPro0EnvSpec",
dm_cls="YGOPro0DMEnvPool",
gym_cls="YGOPro0GymEnvPool",
gymnasium_cls="YGOPro0GymnasiumEnvPool",
)
#include "ygoenv/ygopro0/ygopro.h"
#include "ygoenv/core/py_envpool.h"
using YGOPro0EnvSpec = PyEnvSpec<ygopro0::YGOProEnvSpec>;
using YGOPro0EnvPool = PyEnvPool<ygopro0::YGOProEnvPool>;
PYBIND11_MODULE(ygopro0_ygoenv, m) {
REGISTER(m, YGOPro0EnvSpec, YGOPro0EnvPool)
m.def("init_module", &ygopro0::init_module);
}
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment