Commit 389fcd9c authored by sbl1996@126.com

fixing set_obs_cards seg fault

parent 46b4b5ae
-SCRIPTS_REPO := "https://github.com/Fluorohydride/ygopro-scripts.git"
+SCRIPTS_REPO := "https://github.com/mycard/ygopro-scripts.git"
 SCRIPTS_DIR := "../ygopro-scripts"
 DATABASE_REPO := "https://github.com/mycard/ygopro-database/raw/master/locales"
 LOCALES := en zh
...
@@ -11,7 +11,7 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
 ## Building
-### prerequisites
+### Prerequisites
 - gcc 10+ or clang 11+
 - [xmake](https://xmake.io/#/getting_started)
 - PyTorch 2.0 or later with cuda support
...
@@ -32,7 +32,7 @@
 35059553
 24224830
 24224830
-84211599
+97268402
 73628505
 40155014
 40155014
@@ -58,3 +58,4 @@
 02857636
 !side
 27204312
+92907249
@@ -28,7 +28,7 @@
 51697825
 51697825
 51697825
-84211599
+98645731
 28126717
 55521751
 24224830
...
@@ -19,7 +19,7 @@
 80433039
 17827173
 24508238
-84211599
+98645731
 75500286
 98645731
 49238328
...
@@ -23,7 +23,7 @@
 97268402
 2511
 2511
-84211599
+97268402
 24224830
 24224830
 33407125
...
@@ -24,7 +24,7 @@
 1984618
 1984618
 35261759
-84211599
+49238328
 49238328
 84797028
 84797028
...
@@ -24,7 +24,7 @@
 24508238
 2295440
 51405049
-84211599
+27204311
 89023486
 24224830
 24224830
@@ -58,3 +58,4 @@
 94259633
 !side
 52340445
+27204312
@@ -840,3 +840,4 @@
 35809262
 92731385
 74018812
+92907249
@@ -5,7 +5,6 @@ from collections import deque
 from dataclasses import dataclass
 from typing import Literal, Optional
 import ygoenv
 import numpy as np
 import tyro
@@ -22,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
 from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
-from ygoai.rl.ppo import bootstrap_value_self
+from ygoai.rl.ppo import bootstrap_value_selfplay
 from ygoai.rl.eval import evaluate
@@ -79,11 +78,6 @@ class Args:
 gae_lambda: float = 0.95
 """the lambda for the general advantage estimation"""
-update_win_rate: float = 0.55
-"""the required win rate to update the agent"""
-update_return: float = 0.1
-"""the required return to update the agent"""
 minibatch_size: int = 256
 """the mini-batch size"""
 update_epochs: int = 2
@@ -265,28 +259,21 @@ def main():
 scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
-agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
-agent_t.eval()
-agent_t.load_state_dict(agent.state_dict())
-def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
+def predict_step(agent: Agent, next_obs):
 with torch.no_grad():
 with autocast(enabled=args.fp16_eval):
 logits, value, valid = agent(next_obs)
-logits_t, value_t, valid = agent_t(next_obs)
-logits = torch.where(learn[:, None], logits, logits_t)
-value = torch.where(learn[:, None], value, value_t)
 return logits, value
 from ygoai.rl.ppo import train_step
 if args.compile:
 # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
 # predict_step = torch.compile(predict_step, mode=args.compile)
-agent = torch.compile(agent, mode=args.compile)
-example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
+obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
 with torch.no_grad():
-traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
-traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
+traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
+train_step = torch.compile(train_step, mode=args.compile)
 # ALGO Logic: Storage setup
 obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
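As an aside on the hunk above: the rollout/eval path switches from a torch.compile'd target network to a torch.jit.trace'd copy of the current agent, while torch.compile is kept only for the training step (the in-code comment blames running torch.compile twice for the start-up segfault). Below is a minimal, self-contained sketch of that split, using a toy `Policy` module and made-up tensor shapes rather than the repo's actual `Agent` and observation space:

```python
import torch
import torch.nn as nn

class Policy(nn.Module):  # toy stand-in for the repo's Agent
    def __init__(self, dim=8, n_actions=4):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.logits_head = nn.Linear(dim, n_actions)
        self.value_head = nn.Linear(dim, 1)

    def forward(self, obs):
        h = torch.relu(self.body(obs))
        return self.logits_head(h), self.value_head(h)

agent = Policy()
example_obs = torch.zeros(2, 8)  # batch of example observations for tracing

# Rollout/eval path: trace the module once with example inputs instead of
# compiling it, so only train_step below goes through torch.compile.
with torch.no_grad():
    traced_model = torch.jit.trace(agent, (example_obs,), check_trace=False)

def train_step(obs):
    logits, value = agent(obs)  # gradients still flow through the original module
    return logits.sum() + value.sum()

train_step = torch.compile(train_step)

with torch.no_grad():
    logits, value = traced_model(example_obs)  # fast no-grad forward for rollouts
```

Gradients only ever flow through the original module; the traced copy is used purely for no-grad rollouts and evaluation.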
@@ -298,7 +285,6 @@ def main():
 learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
 avg_ep_returns = deque(maxlen=1000)
 avg_win_rates = deque(maxlen=1000)
-version = 0
 # TRY NOT TO MODIFY: start the game
 global_step = 0
@@ -315,7 +301,7 @@ def main():
 ])
 np.random.shuffle(ai_player1_)
 ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
-next_value = 0
+next_value1 = next_value2 = 0
 for iteration in range(1, args.num_iterations + 1):
 # Annealing the rate if instructed to do so.
@@ -324,8 +310,6 @@ def main():
 lrnow = frac * args.learning_rate
 optimizer.param_groups[0]["lr"] = lrnow
-agent.eval()
 model_time = 0
 env_time = 0
 collect_start = time.time()
@@ -339,7 +323,7 @@ def main():
 learns[step] = learn
 _start = time.time()
-logits, value = predict_step(agent, traced_model_t, next_obs, learn)
+logits, value = predict_step(traced_model, next_obs)
 value = value.flatten()
 probs = Categorical(logits=logits)
 action = probs.sample()
@@ -352,7 +336,8 @@ def main():
 model_time += time.time() - _start
 next_nonterminal = 1 - next_done.float()
-next_value = torch.where(learn, value, next_value) * next_nonterminal
+next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
+next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
 _start = time.time()
 to_play = next_to_play_
@@ -393,16 +378,14 @@ def main():
 _start = time.time()
 # bootstrap value if not done
 with torch.no_grad():
-value = agent(next_obs)[1].reshape(-1)
-value_t = traced_model_t(next_obs)[1].reshape(-1)
-value = torch.where(next_to_play == ai_player1, value, value_t)
-nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
-advantages = bootstrap_value_self(
-values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
+value = traced_model(next_obs)[1].reshape(-1)
+nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
+nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
+advantages = bootstrap_value_selfplay(
+values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
 returns = advantages + values
 bootstrap_time = time.time() - _start
-agent.train()
 _start = time.time()
 # flatten the batch
 b_obs = {
@@ -475,31 +458,11 @@ def main():
 if rank == 0:
 writer.add_scalar("charts/SPS", SPS, global_step)
-if rank == 0:
-should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
-should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
-else:
-should_update = torch.zeros((), dtype=torch.int64, device=device)
-if args.world_size > 1:
-dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
-should_update = should_update.item() > 0
-if should_update:
-agent_t.load_state_dict(agent.state_dict())
-with torch.no_grad():
-traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
-traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
-version += 1
-if rank == 0:
-torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
-print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
-avg_win_rates.clear()
-avg_ep_returns.clear()
+if iteration % args.eval_interval == 0:
+# Eval with rule-based policy
 _start = time.time()
-agent.eval()
 eval_return = evaluate(
-eval_envs, agent, local_eval_episodes, device, args.fp16_eval)
+eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
 eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
 # sync the statistics
...
+import math
 import torch
 import torch.nn as nn
@@ -17,6 +19,29 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
 intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0)
 return points, intervals
+class PositionalEncoding(nn.Module):
+def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
+super().__init__()
+self.dropout = nn.Dropout(p=dropout)
+position = torch.arange(max_len).unsqueeze(1)
+div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+pe = torch.zeros(max_len, 1, d_model)
+pe[:, 0, 0::2] = torch.sin(position * div_term)
+pe[:, 0, 1::2] = torch.cos(position * div_term)
+self.register_buffer('pe', pe)
+def forward(self, x):
+"""
+Arguments:
+x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
+"""
+x = x + self.pe[:x.size(0)]
+return self.dropout(x)
 class Encoder(nn.Module):
 def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
@@ -122,7 +147,6 @@ class Encoder(nn.Module):
 nn.Linear(c, c),
 )
 self.h_id_fc_emb = linear(1024, c)
 self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
 self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
@@ -134,6 +158,7 @@ class Encoder(nn.Module):
 for i in range(num_action_layers)
 ])
+self.action_history_pe = PositionalEncoding(c, dropout=0.0)
 self.action_history_net = nn.ModuleList([
 nn.TransformerDecoderLayer(
 c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
@@ -322,6 +347,7 @@ class Encoder(nn.Module):
 x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
 x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
 f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
+f_h_actions = self.action_history_pe(f_h_actions)
 for layer in self.action_history_net:
 f_actions = layer(f_actions, f_h_actions)
...
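The added PositionalEncoding above is essentially the standard sinusoidal encoding (as in the PyTorch transformer tutorial); per its docstring it expects input of shape `[seq_len, batch_size, embedding_dim]` and simply adds the precomputed sin/cos table before dropout. A small hypothetical usage sketch, with shapes chosen only for illustration and assuming the class from the diff above is in scope:

```python
import torch

# assumes the PositionalEncoding class added in the diff above is in scope
pe = PositionalEncoding(d_model=128, dropout=0.0, max_len=64)

x = torch.zeros(10, 4, 128)   # [seq_len=10, batch_size=4, embedding_dim=128]
y = pe(x)                     # same shape, with sin/cos position offsets added
assert y.shape == (10, 4, 128)
```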
@@ -108,7 +108,6 @@ def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done,
 def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
-# TODO: drop epsilon steps for estimated nextvalues
 num_steps = rewards.size(0)
 advantages = torch.zeros_like(rewards)
 # TODO: optimize this
...
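The training script now calls `bootstrap_value_selfplay` with two bootstrap streams (`nextvalues1`/`nextvalues2`). Its body is not shown in this diff; judging by the signature, it applies the usual GAE recursion per player, using the `learns` mask to route each step to the right stream. For reference, here is a single-stream GAE sketch of the underlying recursion, with toy shapes and arbitrary gamma/lambda values (not the repo's implementation):

```python
import torch

def gae_single_stream(values, rewards, dones, next_value, next_done, gamma, gae_lambda):
    # Standard GAE recursion:
    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
    num_steps = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_value)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done.float()
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1].float()
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages

# toy usage: 4 steps, 2 parallel envs
T, N = 4, 2
values = torch.zeros(T, N)
rewards = torch.ones(T, N)
dones = torch.zeros(T, N, dtype=torch.bool)
adv = gae_single_stream(values, rewards, dones,
                        next_value=torch.zeros(N),
                        next_done=torch.zeros(N, dtype=torch.bool),
                        gamma=0.997, gae_lambda=0.95)
returns = adv + values
```

The self-play variant presumably runs the same recursion over each player's own turns, which is why it needs a separate bootstrap value per player.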
@@ -365,8 +365,7 @@ inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos) {
 return spec;
 }
-inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos,
-bool opponent) {
+inline std::string ls_to_spec(uint8_t loc, uint8_t seq, uint8_t pos, bool opponent) {
 std::string spec = ls_to_spec(loc, seq, pos);
 if (opponent) {
 spec.insert(0, 1, 'o');
@@ -1471,6 +1470,7 @@ public:
 int init_lp = 8000;
 int startcount = 5;
 int drawcount = 1;
 for (PlayerId i = 0; i < 2; i++) {
 if (players_[i] != nullptr) {
 delete players_[i];
@@ -1641,10 +1641,12 @@ public:
 void Step(const Action &action) override {
 // clock_t start = clock();
+// fmt::println("Step");
 int idx = action["action"_];
 callback_(idx);
-update_history_actions(to_play_, idx);
+// update_history_actions(to_play_, idx);
+// fmt::println("update_history_actions");
 PlayerId player = to_play_;
@@ -2153,6 +2155,10 @@ private:
 ReplayWriteInt8(1);
 fwrite(buf, 1, 1, fp_);
 break;
+case MSG_SELECT_COUNTER:
+ReplayWriteInt8(2);
+fwrite(buf, 2, 1, fp_);
+break;
 case MSG_SELECT_PLACE:
 case MSG_SELECT_DISFIELD:
 ReplayWriteInt8(3);
@@ -2183,10 +2189,13 @@ private:
 state["obs:global_"_][22] = uint8_t(1);
 return;
 }
+// fmt::println("writestate");
 auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
+// fmt::println("_set_obs_cards");
-_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
+// _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
+// fmt::println("_set_obs_global");
 // we can't shuffle because idx must be stable in callback
 if (n_options > max_options()) {
@@ -2198,10 +2207,12 @@ private:
 // fmt::println("{} {}", key, val);
 // }
-_set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
+// _set_obs_actions(state["obs:actions_"_], spec2index, msg_, options_);
+// fmt::println("_set_obs_actions");
 n_options = options_.size();
 state["info:num_options"_] = n_options;
+return;
 // update h_card_ids from state
 auto &h_card_ids = to_play_ == 0 ? h_card_ids_0_ : h_card_ids_1_;
@@ -2222,6 +2233,7 @@ private:
 }
 h_card_ids[i] = card_ids;
 }
+// fmt::println("update h_card_ids");
 // write history actions
@@ -2235,6 +2247,7 @@ private:
 n_action_feats * n1);
 state["obs:h_actions_"_][n1].Assign((uint8_t *)history_actions.Data(),
 n_action_feats * ha_p);
+// fmt::println("write history actions");
 }
 void show_decision(int idx) {
@@ -2307,8 +2320,8 @@ private:
 if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) {
 if (options_.size() == 1) {
 callback_(0);
-update_h_card_ids(to_play_, 0);
-update_history_actions(to_play_, 0);
+// update_h_card_ids(to_play_, 0);
+// update_history_actions(to_play_, 0);
 if (verbose_) {
 show_decision(0);
 }
@@ -2488,6 +2501,7 @@ private:
 }
 cards.push_back(c);
 }
+fmt::println("qdp: {}, bl: {}, n: {}", qdp_, bl, cards.size());
 return cards;
 }
@@ -2513,8 +2527,7 @@ private:
 return cards;
 }
-std::vector<IdleCardSpec> read_cardlist_spec(bool extra = false,
-bool extra8 = false) {
+std::vector<IdleCardSpec> read_cardlist_spec(PlayerId player, bool extra = false, bool extra8 = false) {
 std::vector<IdleCardSpec> card_specs;
 auto count = read_u8();
 card_specs.reserve(count);
@@ -2531,7 +2544,7 @@ private:
 data = read_u32();
 }
 }
-card_specs.push_back({code, ls_to_spec(loc, seq, 0), data});
+card_specs.push_back({code, ls_to_spec(loc, seq, 0, player != controller), data});
 }
 return card_specs;
 }
@@ -2988,6 +3001,7 @@ private:
 if (verbose_) {
 cards.push_back(get_card(c, loc, seq));
 }
+// TODO: check if this is correct
 revealed_.push_back(ls_to_spec(loc, seq, 0, c == player));
 }
 if (!verbose_) {
@@ -3406,8 +3420,8 @@ private:
 throw std::runtime_error("Retry");
 } else if (msg_ == MSG_SELECT_BATTLECMD) {
 auto player = read_u8();
-auto activatable = read_cardlist_spec(true);
-auto attackable = read_cardlist_spec(true, true);
+auto activatable = read_cardlist_spec(player, true);
+auto attackable = read_cardlist_spec(player, true, true);
 bool to_m2 = read_u8();
 bool to_ep = read_u8();
@@ -4122,12 +4136,12 @@ private:
 };
 } else if (msg_ == MSG_SELECT_IDLECMD) {
 int32_t player = read_u8();
-auto summonable_ = read_cardlist_spec();
-auto spsummon_ = read_cardlist_spec();
-auto repos_ = read_cardlist_spec();
-auto idle_mset_ = read_cardlist_spec();
-auto idle_set_ = read_cardlist_spec();
-auto idle_activate_ = read_cardlist_spec(true);
+auto summonable_ = read_cardlist_spec(player);
+auto spsummon_ = read_cardlist_spec(player);
+auto repos_ = read_cardlist_spec(player);
+auto idle_mset_ = read_cardlist_spec(player);
+auto idle_set_ = read_cardlist_spec(player);
+auto idle_activate_ = read_cardlist_spec(player, true);
 bool to_bp_ = read_u8();
 bool to_ep_ = read_u8();
 read_u8(); // can_shuffle
@@ -4332,6 +4346,35 @@ private:
 resp_buf_[2] = seq;
 YGO_SetResponseb(pduel_, resp_buf_);
 };
+} else if (msg_ == MSG_SELECT_COUNTER) {
+auto player = read_u8();
+auto counter_type = read_u16();
+auto counter_count = read_u16();
+int count = read_u8();
+if (count != 1) {
+throw std::runtime_error("Select counter count " +
+std::to_string(count) + " not implemented");
+}
+auto pl = players_[player];
+if (verbose_) {
+pl->notify(fmt::format("Type new {} for {} card(s), separated by spaces.", "UNKNOWN_COUNTER", count));
+}
+for (int i = 0; i < count; ++i) {
+auto code = read_u32();
+auto controller = read_u8();
+auto loc = read_u8();
+auto seq = read_u8();
+auto counter = read_u16();
+if (verbose_) {
+pl->notify(c_get_card(code).name_ + ": " + std::to_string(counter));
+}
+// auto spec = ls_to_spec(loc, seq, 0, controller != player);
+// options_.push_back(spec);
+}
+// TODO: implement action
+uint16_t resp = counter_count & 0xffff;
+memcpy(resp_buf_, &resp, 2);
+YGO_SetResponseb(pduel_, resp_buf_);
 } else if (msg_ == MSG_ANNOUNCE_NUMBER) {
 auto player = read_u8();
 int count = read_u8();
@@ -4448,6 +4491,11 @@ private:
 } else {
 show_deck(0);
 show_deck(1);
+// print byte by byte
+for (int i = 0; i < dp_; ++i) {
+fmt::print("{:02x} ", data_[i]);
+}
+fmt::print("\n");
 throw std::runtime_error(
 fmt::format("Unknown message {}, length {}, dp {}",
 msg_to_string(msg_), dl_, dp_));
...