Commit b5190a73 authored by biluo.shen's avatar biluo.shen

(WIP) implement multi_select with multiple action

parent 58379501
......@@ -42,8 +42,8 @@
- is_end: 1, int, 0: False, 1: True
## Legal Actions (max 8)
- spec index: 8, int, select target
## Legal Actions (max 24)
- spec index: 2, int, select target
- msg: 1, int (16)
- act: 1, int (11)
- N/A
......@@ -66,12 +66,15 @@
- Battle (b)
- Main Phase 2 (m)
- End Phase (e)
- cancel_finish: 1, int (3)
- cancel: 1
- N/A
- Cancel
- finish: 1
- N/A
- Finish
- position: 1, int , 0: N/A, same as position2str
- option: 1, int, 0: N/A
- number: 1, int, 0: N/A
- place: 1, int (31), 0: N/A,
- 1-7: m
- 8-15: s
......
......@@ -1235,11 +1235,11 @@ public:
"play_mode"_.Bind(std::string("bot")),
"verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"max_multi_select"_.Bind(5), "record"_.Bind(false));
"record"_.Bind(false));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
int n_action_feats = 10 + conf["max_multi_select"_] * 2;
int n_action_feats = 13;
return MakeDict(
"obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"obs:global_"_.Bind(Spec<uint8_t>({23})),
......@@ -1370,15 +1370,24 @@ protected:
// circular buffer for history actions of player 0
TArray<uint8_t> history_actions_0_;
int ha_p_0_ = 0;
std::vector<std::vector<CardId>> h_card_ids_0_;
std::vector<CardId> h_card_ids_0_;
// circular buffer for history actions of player 1
TArray<uint8_t> history_actions_1_;
int ha_p_1_ = 0;
std::vector<std::vector<CardId>> h_card_ids_1_;
std::vector<CardId> h_card_ids_1_;
std::unordered_set<std::string> revealed_;
// multi select
int ms_idx_ = -1;
int ms_min_ = 0;
int ms_max_ = 0;
std::vector<std::string> ms_specs_;
ankerl::unordered_dense::map<std::string, int> ms_spec2idx_;
std::vector<int> ms_r_idxs_;
// discard hand cards
bool discard_hand_ = false;
......@@ -1582,7 +1591,7 @@ public:
void update_h_card_ids(PlayerId player, int idx) {
auto &h_card_ids = player == 0 ? h_card_ids_0_ : h_card_ids_1_;
h_card_ids[idx] = parse_card_ids(options_[idx], player);
h_card_ids[idx] = parse_card_id(options_[idx], player);
}
void update_history_actions(PlayerId player, int idx) {
......@@ -1623,17 +1632,14 @@ public:
// print card ids of history actions
for (int i = 0; i < n_history_actions_; ++i) {
fmt::print("history {}\n", i);
uint8_t msg_id = uint8_t(ha(i, _obs_action_feat_offset()));
uint8_t msg_id = uint8_t(ha(i, 2));
int msg = _msgs[msg_id - 1];
fmt::print("msg: {},", msg_to_string(msg));
for (int j = 0; j < spec_.config["max_multi_select"_]; j++) {
auto v1 = static_cast<CardId>(ha(i, 2 * j));
auto v2 = static_cast<CardId>(ha(i, 2 * j + 1));
auto v1 = static_cast<CardId>(ha(i, 0));
auto v2 = static_cast<CardId>(ha(i, 1));
CardId card_id = (v1 << 8) + v2;
fmt::print(" {}", card_id);
}
fmt::print(";");
for (int j = _obs_action_feat_offset() + 1; j < ha.Shape()[1]; j++) {
fmt::print(" {};", card_id);
for (int j = 3; j < ha.Shape()[1]; j++) {
fmt::print(" {}", uint8_t(ha(i, j)));
}
fmt::print("\n");
......@@ -1653,7 +1659,36 @@ public:
show_decision(idx);
}
if (ms_idx_ != -1) {
options_ = {};
for (int j = 0; j < ms_specs_.size(); ++j) {
if (ms_spec2idx_.find(ms_specs_[j]) != ms_spec2idx_.end()) {
options_.push_back(ms_specs_[j]);
}
}
int midx = ms_idx_ + 1;
if (midx >= ms_min_ && midx < ms_max_) {
options_.push_back("f");
callback_ = [this](int idx) {
const auto &option = options_[idx];
if (option[0] == 'f') {
ms_idx_ = -1;
resp_buf_[0] = ms_r_idxs_.size();
for (int i = 0; i < ms_r_idxs_.size(); ++i) {
resp_buf_[i + 1] = ms_r_idxs_[i];
}
YGO_SetResponseb(pduel_, resp_buf_);
} else {
idx = ms_spec2idx_.at(option);
ms_idx_++;
ms_r_idxs_.push_back(idx);
ms_spec2idx_.erase(ms_specs_[idx]);
}
};
}
} else {
next();
}
float reward = 0;
int reason = 0;
......@@ -1846,13 +1881,13 @@ private:
}
}
void _set_obs_action_spec(TArray<uint8_t> &feat, int i, int j,
void _set_obs_action_spec(TArray<uint8_t> &feat, int i,
const std::string &spec,
const SpecIndex &spec2index,
const std::vector<CardId> &card_ids) {
CardId card_id = 0) {
uint16_t idx;
if (spec2index.empty()) {
idx = card_ids[j];
idx = card_id;
} else {
auto it = spec2index.find(spec);
if (it == spec2index.end()) {
......@@ -1868,61 +1903,59 @@ private:
idx = it->second;
}
}
feat(i, 2 * j) = static_cast<uint8_t>(idx >> 8);
feat(i, 2 * j + 1) = static_cast<uint8_t>(idx & 0xff);
}
int _obs_action_feat_offset() const {
return spec_.config["max_multi_select"_] * 2;
feat(i, 0) = static_cast<uint8_t>(idx >> 8);
feat(i, 1) = static_cast<uint8_t>(idx & 0xff);
}
void _set_obs_action_msg(TArray<uint8_t> &feat, int i, int msg) {
feat(i, _obs_action_feat_offset()) = msg2id.at(msg);
feat(i, 2) = msg2id.at(msg);
}
void _set_obs_action_act(TArray<uint8_t> &feat, int i, char act,
uint8_t act_offset = 0) {
feat(i, _obs_action_feat_offset() + 1) = cmd_act2id.at(act) + act_offset;
feat(i, 3) = cmd_act2id.at(act) + act_offset;
}
void _set_obs_action_yesno(TArray<uint8_t> &feat, int i, char yesno) {
feat(i, _obs_action_feat_offset() + 2) = cmd_yesno2id.at(yesno);
feat(i, 4) = cmd_yesno2id.at(yesno);
}
void _set_obs_action_phase(TArray<uint8_t> &feat, int i, char phase) {
feat(i, _obs_action_feat_offset() + 3) = cmd_phase2id.at(phase);
feat(i, 5) = cmd_phase2id.at(phase);
}
void _set_obs_action_cancel(TArray<uint8_t> &feat, int i) {
feat(i, 6) = 1;
}
void _set_obs_action_cancel_finish(TArray<uint8_t> &feat, int i, char c) {
uint8_t v = c == 'c' ? 1 : (c == 'f' ? 2 : 0);
feat(i, _obs_action_feat_offset() + 4) = v;
void _set_obs_action_finish(TArray<uint8_t> &feat, int i) {
feat(i, 7) = 1;
}
void _set_obs_action_position(TArray<uint8_t> &feat, int i, char position) {
position = 1 << (position - '1');
feat(i, _obs_action_feat_offset() + 5) = position2id.at(position);
feat(i, 8) = position2id.at(position);
}
void _set_obs_action_option(TArray<uint8_t> &feat, int i, char option) {
feat(i, _obs_action_feat_offset() + 6) = option - '0';
feat(i, 9) = option - '0';
}
void _set_obs_action_number(TArray<uint8_t> &feat, int i, char number) {
feat(i, _obs_action_feat_offset() + 7) = number - '0';
feat(i, 10) = number - '0';
}
void _set_obs_action_place(TArray<uint8_t> &feat, int i,
const std::string &spec) {
feat(i, _obs_action_feat_offset() + 8) = cmd_place2id.at(spec);
void _set_obs_action_place(TArray<uint8_t> &feat, int i, const std::string &spec) {
feat(i, 11) = cmd_place2id.at(spec);
}
void _set_obs_action_attrib(TArray<uint8_t> &feat, int i, uint8_t attrib) {
feat(i, _obs_action_feat_offset() + 9) = attribute2id.at(attrib);
feat(i, 12) = attribute2id.at(attrib);
}
void _set_obs_action(TArray<uint8_t> &feat, int i, int msg,
const std::string &option, const SpecIndex &spec2index,
const std::vector<CardId> &card_ids) {
CardId card_id) {
_set_obs_action_msg(feat, i, msg);
if (msg == MSG_SELECT_IDLECMD) {
if (option == "b" || option == "e") {
......@@ -1938,11 +1971,11 @@ private:
}
_set_obs_action_act(feat, i, act, offset);
_set_obs_action_spec(feat, i, 0, spec, spec2index, card_ids);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_CHAIN) {
if (option[0] == 'c') {
_set_obs_action_cancel_finish(feat, i, option[0]);
_set_obs_action_cancel(feat, i);
} else {
char act = 'v';
auto spec = option;
......@@ -1954,42 +1987,20 @@ private:
}
_set_obs_action_act(feat, i, act, offset);
_set_obs_action_spec(feat, i, 0, spec, spec2index, card_ids);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_CARD || msg == MSG_SELECT_TRIBUTE ||
msg == MSG_SELECT_SUM) {
if (spec2index.empty()) {
for (int k = 0; k < card_ids.size(); ++k) {
_set_obs_action_spec(feat, i, k, option, spec2index, card_ids);
}
} else {
int k = 0;
size_t start = 0;
while (start < option.size()) {
size_t idx = option.find_first_of(" ", start);
if (idx == std::string::npos) {
auto spec = option.substr(start);
_set_obs_action_spec(feat, i, k, spec, spec2index, {});
break;
} else {
auto spec = option.substr(start, idx - start);
_set_obs_action_spec(feat, i, k, spec, spec2index, {});
k++;
start = idx + 1;
}
}
}
} else if (msg == MSG_SELECT_UNSELECT_CARD) {
msg == MSG_SELECT_SUM || msg == MSG_SELECT_UNSELECT_CARD) {
if (option[0] == 'f') {
_set_obs_action_cancel_finish(feat, i, option[0]);
_set_obs_action_finish(feat, i);
} else {
_set_obs_action_spec(feat, i, 0, option, spec2index, card_ids);
_set_obs_action_spec(feat, i, option, spec2index, card_id);
}
} else if (msg == MSG_SELECT_POSITION) {
_set_obs_action_position(feat, i, option[0]);
} else if (msg == MSG_SELECT_EFFECTYN) {
auto spec = option.substr(2);
_set_obs_action_spec(feat, i, 0, spec, spec2index, card_ids);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
_set_obs_action_yesno(feat, i, option[0]);
} else if (msg == MSG_SELECT_YESNO) {
......@@ -2001,7 +2012,7 @@ private:
auto act = option[0];
auto spec = option.substr(2);
_set_obs_action_act(feat, i, act);
_set_obs_action_spec(feat, i, 0, spec, spec2index, card_ids);
_set_obs_action_spec(feat, i, spec, spec2index, card_id);
}
} else if (msg == MSG_SELECT_OPTION) {
_set_obs_action_option(feat, i, option[0]);
......@@ -2018,6 +2029,7 @@ private:
CardId spec_to_card_id(const std::string &spec, PlayerId player) {
int offset = 0;
// TODO: possible info leak
if (spec[0] == 'o') {
player = 1 - player;
offset++;
......@@ -2026,54 +2038,40 @@ private:
return card_ids_.at(get_card_code(player, loc, seq));
}
std::vector<CardId> parse_card_ids(const std::string &option,
PlayerId player) {
std::vector<CardId> card_ids;
CardId parse_card_id(const std::string &option, PlayerId player) {
CardId card_id = 0;
if (msg_ == MSG_SELECT_IDLECMD) {
if (!(option == "b" || option == "e")) {
auto n = option.size();
if (std::isalpha(option[n - 1])) {
card_ids.push_back(spec_to_card_id(option.substr(2, n - 3), player));
card_id = spec_to_card_id(option.substr(2, n - 3), player);
} else {
card_ids.push_back(spec_to_card_id(option.substr(2), player));
card_id = spec_to_card_id(option.substr(2), player);
}
}
} else if (msg_ == MSG_SELECT_CHAIN) {
if (option != "c") {
card_ids.push_back(spec_to_card_id(option, player));
card_id = spec_to_card_id(option, player);
}
} else if (msg_ == MSG_SELECT_CARD || msg_ == MSG_SELECT_TRIBUTE ||
msg_ == MSG_SELECT_SUM) {
size_t start = 0;
while (start < option.size()) {
size_t idx = option.find_first_of(" ", start);
if (idx == std::string::npos) {
card_ids.push_back(spec_to_card_id(option.substr(start), player));
break;
} else {
card_ids.push_back(
spec_to_card_id(option.substr(start, idx - start), player));
start = idx + 1;
}
}
} else if (msg_ == MSG_SELECT_UNSELECT_CARD) {
msg_ == MSG_SELECT_SUM || msg_ == MSG_SELECT_UNSELECT_CARD) {
if (option[0] != 'f') {
card_ids.push_back(spec_to_card_id(option, player));
card_id = spec_to_card_id(option, player);
}
} else if (msg_ == MSG_SELECT_EFFECTYN) {
card_ids.push_back(spec_to_card_id(option.substr(2), player));
card_id = spec_to_card_id(option.substr(2), player);
} else if (msg_ == MSG_SELECT_BATTLECMD) {
if (!(option == "m" || option == "e")) {
card_ids.push_back(spec_to_card_id(option.substr(2), player));
card_id = spec_to_card_id(option.substr(2), player);
}
}
return card_ids;
return card_id;
}
void _set_obs_actions(TArray<uint8_t> &feat, const SpecIndex &spec2index,
int msg, const std::vector<std::string> &options) {
for (int i = 0; i < options.size(); ++i) {
_set_obs_action(feat, i, msg, options[i], spec2index, {});
_set_obs_action(feat, i, msg, options[i], spec2index, 0);
}
}
......@@ -2210,20 +2208,18 @@ private:
auto &h_card_ids = to_play_ == 0 ? h_card_ids_0_ : h_card_ids_1_;
for (int i = 0; i < n_options; ++i) {
std::vector<CardId> card_ids;
for (int j = 0; j < spec_.config["max_multi_select"_]; ++j) {
uint8_t spec_index = state["obs:actions_"_](i, 2 * j + 1);
uint8_t spec_index1 = state["obs:actions_"_](i, 0);
uint8_t spec_index2 = state["obs:actions_"_](i, 1);
uint16_t spec_index = (spec_index1 << 8) + spec_index2;
if (spec_index == 0) {
break;
}
// because of na_card_embed, we need to subtract 1
h_card_ids[i] = 0;
} else {
uint16_t card_id1 =
static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 0));
uint16_t card_id2 =
static_cast<uint16_t>(state["obs:cards_"_](spec_index - 1, 1));
card_ids.push_back((card_id1 << 8) + card_id2);
h_card_ids[i] = (card_id1 << 8) + card_id2;
}
h_card_ids[i] = card_ids;
}
// write history actions
......@@ -2302,11 +2298,13 @@ private:
}
YGO_GetMessage(pduel_, data_);
dp_ = 0;
while (dp_ != dl_) {
while ((dp_ != dl_) || (ms_idx_ != -1)) {
if (ms_idx_ == -1) {
handle_message();
if (options_.empty()) {
continue;
}
}
if ((play_mode_ == kSelfPlay) || (to_play_ == ai_player_)) {
if (options_.size() == 1) {
callback_(0);
......@@ -2549,6 +2547,10 @@ private:
return card.name_ + " (" + spec + ")";
}
// This function does the following:
// 1. read msg_ from data_ and update dp_
// 2. (optional) print information if verbose_ is true
// 3. update to_play_ and options_ if need action
void handle_message() {
msg_ = int(data_[dp_++]);
options_ = {};
......@@ -3545,6 +3547,10 @@ private:
auto max = read_u8();
auto size = read_u8();
if (min == 0) {
throw std::runtime_error("Min == 0 not implemented for select card");
}
std::vector<std::string> specs;
specs.reserve(size);
if (verbose_) {
......@@ -3580,7 +3586,6 @@ private:
}
}
if (min > spec_.config["max_multi_select"_]) {
if (discard_hand_) {
// random discard
std::vector<int> comb(size);
......@@ -3595,46 +3600,29 @@ private:
return;
}
show_turn();
ms_idx_ = 0;
ms_min_ = min;
ms_max_ = max;
ms_specs_ = specs;
ms_spec2idx_.clear();
show_deck(player);
show_history_actions(player);
show_deck(1-player);
show_history_actions(1-player);
fmt::println("player: {}, min: {}, max: {}, size: {}", player, min, max, size);
std::cout << std::flush;
throw std::runtime_error(
fmt::format("Min > {} not implemented for select card",
spec_.config["max_multi_select"_]));
}
max = std::min(max, uint8_t(spec_.config["max_multi_select"_]));
std::vector<std::vector<int>> combs;
for (int i = min; i <= max; ++i) {
for (const auto &comb : combinations(size, i)) {
combs.push_back(comb);
std::string option = "";
for (int j = 0; j < i; ++j) {
option += specs[comb[j]];
if (j < i - 1) {
option += " ";
}
}
options_.push_back(option);
}
for (int j = 0; j < ms_specs_.size(); ++j) {
const auto &spec = ms_specs_[j];
options_.push_back(spec);
ms_spec2idx_[spec] = j;
}
to_play_ = player;
callback_ = [this, combs](int idx) {
const auto &comb = combs[idx];
resp_buf_[0] = comb.size();
for (int i = 0; i < comb.size(); ++i) {
resp_buf_[i + 1] = comb[i];
}
callback_ = [this](int idx) {
if (ms_max_ == 1) {
ms_idx_ = -1;
resp_buf_[0] = 1;
resp_buf_[1] = static_cast<uint8_t>(idx);
YGO_SetResponseb(pduel_, resp_buf_);
}
ms_idx_++;
ms_r_idxs_.push_back(idx);
ms_spec2idx_.erase(ms_specs_[idx]);
};
} else if (msg_ == MSG_SELECT_TRIBUTE) {
auto player = read_u8();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment