Commit c90562de authored by sbl1996@126.com's avatar sbl1996@126.com

Env fault tolerant training via timeout

parent b7e43382
...@@ -311,6 +311,7 @@ if __name__ == "__main__": ...@@ -311,6 +311,7 @@ if __name__ == "__main__":
for idx, d in enumerate(dones1): for idx, d in enumerate(dones1):
if not d or (args.accurate and collected[idx]): if not d or (args.accurate and collected[idx]):
continue continue
# c1 = collected[idx]
collected[idx] = True collected[idx] = True
win_reason = infos1['win_reason'][idx] win_reason = infos1['win_reason'][idx]
pl = 1 if main[idx] else -1 pl = 1 if main[idx] else -1
...@@ -323,6 +324,7 @@ if __name__ == "__main__": ...@@ -323,6 +324,7 @@ if __name__ == "__main__":
win_players.append(win_player) win_players.append(win_player)
win_agent = 1 if main_reward > 0 else 2 win_agent = 1 if main_reward > 0 else 2
win_agents.append(win_agent) win_agents.append(win_agent)
# if not c1:
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}") # print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
episode_lengths.append(episode_length) episode_lengths.append(episode_length)
episode_rewards.append(main_reward) episode_rewards.append(main_reward)
......
...@@ -49,6 +49,8 @@ class Args: ...@@ -49,6 +49,8 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)""" """the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the path to the model checkpoint to load""" """the path to the model checkpoint to load"""
timeout: int = 600
"""the timeout of the environment step"""
debug: bool = False debug: bool = False
"""whether to run the script in debug mode""" """whether to run the script in debug mode"""
...@@ -208,6 +210,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off ...@@ -208,6 +210,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
async_reset=False, async_reset=False,
greedy_reward=args.greedy_reward if not eval else True, greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode, play_mode=mode,
timeout=args.timeout,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
return envs return envs
......
...@@ -210,17 +210,17 @@ if __name__ == "__main__": ...@@ -210,17 +210,17 @@ if __name__ == "__main__":
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if not d: if not d:
continue continue
for i in range(2): # for i in range(2):
deck_time = infos['step_time'][idx][i] # deck_time = infos['step_time'][idx][i]
deck_name = deck_names[infos['deck'][idx][i]] # deck_name = deck_names[infos['deck'][idx][i]]
time_count = deck_time_count[deck_name] # time_count = deck_time_count[deck_name]
avg_time = deck_times[deck_name] # avg_time = deck_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1) # avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
deck_times[deck_name] = avg_time # deck_times[deck_name] = avg_time
deck_time_count[deck_name] += 1 # deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % 100 == 0: # if deck_time_count[deck_name] % 100 == 0:
print(f"Deck {deck_name}: {avg_time:.4f}") # print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx] win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
// clang-format off // clang-format off
#include <algorithm> #include <algorithm>
#include <chrono>
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include <ctime> #include <ctime>
...@@ -1525,7 +1526,8 @@ public: ...@@ -1525,7 +1526,8 @@ public:
"play_mode"_.Bind(std::string("bot")), "play_mode"_.Bind(std::string("bot")),
"verbose"_.Bind(false), "max_options"_.Bind(16), "verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16), "max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(true), "greedy_reward"_.Bind(true)); "record"_.Bind(false), "async_reset"_.Bind(false),
"greedy_reward"_.Bind(true), "timeout"_.Bind(600));
} }
template <typename Config> template <typename Config>
static decltype(auto) StateSpec(const Config &conf) { static decltype(auto) StateSpec(const Config &conf) {
...@@ -1616,12 +1618,12 @@ protected: ...@@ -1616,12 +1618,12 @@ protected:
const int player_; const int player_;
PlayMode play_mode_; PlayMode play_mode_;
const bool verbose_ = false; bool verbose_ = false;
PlayerId ai_player_; PlayerId ai_player_;
intptr_t pduel_ = 0; intptr_t pduel_ = 0;
Player *players_[2]; // abstract class must be pointer std::unique_ptr<Player> players_[2]; // abstract class must be pointer
std::uniform_int_distribution<uint64_t> dist_int_; std::uniform_int_distribution<uint64_t> dist_int_;
bool done_{true}; bool done_{true};
...@@ -1708,19 +1710,25 @@ protected: ...@@ -1708,19 +1710,25 @@ protected:
// async reset // async reset
const bool async_reset_; const bool async_reset_;
int n_lives_ = 0; int n_lives_ = 0;
std::future<MDuel> duel_fut_; // std::future<MDuel> duel_fut_;
BS::thread_pool pool_; // BS::thread_pool pool_;
std::mt19937 duel_gen_; std::mt19937 duel_gen_;
public: public:
YGOProEnvImpl(const EnvSpec<YGOProEnvFns> &spec, uint32_t env_seed) // step return
float ret_reward_ = 0;
int ret_win_reason_ = 0;
YGOProEnvImpl();
YGOProEnvImpl(const EnvSpec<YGOProEnvFns> &spec, uint64_t env_seed)
: spec_(spec), dist_int_(0, 0xffffffff), : spec_(spec), dist_int_(0, 0xffffffff),
deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]), deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]),
player_(spec.config["player"_]), players_{nullptr, nullptr}, player_(spec.config["player"_]), players_{nullptr, nullptr},
play_modes_(parse_play_modes(spec.config["play_mode"_])), play_modes_(parse_play_modes(spec.config["play_mode"_])),
verbose_(spec.config["verbose"_]), record_(spec.config["record"_]), verbose_(spec.config["verbose"_]), record_(spec.config["record"_]),
n_history_actions_(spec.config["n_history_actions"_]), pool_(BS::thread_pool(1)), n_history_actions_(spec.config["n_history_actions"_]),
async_reset_(spec.config["async_reset"_]), greedy_reward_(spec.config["greedy_reward"_]) { async_reset_(spec.config["async_reset"_]), greedy_reward_(spec.config["greedy_reward"_]) {
if (record_) { if (record_) {
if (!verbose_) { if (!verbose_) {
...@@ -1751,14 +1759,6 @@ public: ...@@ -1751,14 +1759,6 @@ public:
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2}))); ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
} }
~YGOProEnvImpl() {
for (int i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
delete players_[i];
}
}
}
int max_options() const { return spec_.config["max_options"_]; } int max_options() const { return spec_.config["max_options"_]; }
int max_cards() const { return spec_.config["max_cards"_]; } int max_cards() const { return spec_.config["max_cards"_]; }
...@@ -1858,21 +1858,17 @@ public: ...@@ -1858,21 +1858,17 @@ public:
extra_deck1_ = mduel.extra_deck1; extra_deck1_ = mduel.extra_deck1;
for (PlayerId i = 0; i < 2; i++) { for (PlayerId i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
delete players_[i];
}
std::string nickname = i == 0 ? "Alice" : "Bob"; std::string nickname = i == 0 ? "Alice" : "Bob";
if (i == ai_player_) { if (i == ai_player_) {
nickname = "Agent"; nickname = "Agent";
} }
nickname_[i] = nickname; nickname_[i] = nickname;
if ((play_mode_ == kHuman) && (i != ai_player_)) { if ((play_mode_ == kHuman) && (i != ai_player_)) {
players_[i] = new HumanPlayer(nickname_[i], init_lp_, i, verbose_); players_[i] = std::make_unique<HumanPlayer>(nickname, init_lp_, i, verbose_);
} else if (play_mode_ == kRandomBot) { } else if (play_mode_ == kRandomBot) {
players_[i] = new RandomAI(max_options(), dist_int_(gen_), nickname_[i], players_[i] = std::make_unique<RandomAI>(max_options(), dist_int_(gen_), nickname, init_lp_, i, verbose_);
init_lp_, i, verbose_);
} else { } else {
players_[i] = new GreedyAI(nickname_[i], init_lp_, i, verbose_); players_[i] = std::make_unique<GreedyAI>(nickname, init_lp_, i, verbose_);
} }
lp_[i] = players_[i]->init_lp_; lp_[i] = players_[i]->init_lp_;
} }
...@@ -1955,6 +1951,9 @@ public: ...@@ -1955,6 +1951,9 @@ public:
next(); next();
ret_reward_ = 0;
ret_win_reason_ = 0;
// if (async_reset_) { // if (async_reset_) {
// duel_fut_ = pool_.submit_task([ // duel_fut_ = pool_.submit_task([
// this, old_duel, duel_seed=dist_int_(gen_)] { // this, old_duel, duel_seed=dist_int_(gen_)] {
...@@ -2210,7 +2209,7 @@ public: ...@@ -2210,7 +2209,7 @@ public:
} }
} }
std::tuple<float, int> step(int idx) { void step(int idx) {
callback_(idx); callback_(idx);
update_history_actions(to_play_, legal_actions_[idx]); update_history_actions(to_play_, legal_actions_[idx]);
...@@ -2301,7 +2300,8 @@ public: ...@@ -2301,7 +2300,8 @@ public:
// if (step_time_count_ % 3000 == 0) { // if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000); // fmt::println("Step time: {:.3f}", step_time_ * 1000);
// } // }
return {reward, reason}; ret_reward_ = reward;
ret_win_reason_ = reason;
} }
using YGOProEnvSpec = EnvSpec<YGOProEnvFns>; using YGOProEnvSpec = EnvSpec<YGOProEnvFns>;
...@@ -2309,16 +2309,17 @@ public: ...@@ -2309,16 +2309,17 @@ public:
Dict<typename YGOProEnvSpec::StateKeys, Dict<typename YGOProEnvSpec::StateKeys,
typename SpecToTArray<typename YGOProEnvSpec::StateSpec::Values>::Type>; typename SpecToTArray<typename YGOProEnvSpec::StateSpec::Values>::Type>;
void WriteState( void WriteState(State &state) {
State &state, float reward, int win_reason = 0, double step_time = 0.0) { float reward = ret_reward_;
int win_reason = ret_win_reason_;
int n_options = legal_actions_.size(); int n_options = legal_actions_.size();
state["reward"_] = reward; state["reward"_] = reward;
state["info:to_play"_] = int(to_play_); state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay); state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason; state["info:win_reason"_] = win_reason;
if (reward != 0.0) { if (reward != 0.0) {
state["info:step_time"_][0] = step_time; state["info:step_time"_][0] = 0;
state["info:step_time"_][1] = step_time; state["info:step_time"_][1] = 0;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]]; state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]]; state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
} }
...@@ -3167,7 +3168,7 @@ private: ...@@ -3167,7 +3168,7 @@ private:
if (!verbose_) { if (!verbose_) {
return; return;
} }
auto player = players_[tp_]; auto& player = players_[tp_];
player->notify("Your turn."); player->notify("Your turn.");
players_[1 - tp_]->notify(fmt::format("{}'s turn.", player->nickname_)); players_[1 - tp_]->notify(fmt::format("{}'s turn.", player->nickname_));
} else if (msg_ == MSG_NEW_PHASE) { } else if (msg_ == MSG_NEW_PHASE) {
...@@ -3192,23 +3193,23 @@ private: ...@@ -3192,23 +3193,23 @@ private:
card.set_location(location); card.set_location(location);
Card cnew = c_get_card(code); Card cnew = c_get_card(code);
cnew.set_location(newloc); cnew.set_location(newloc);
auto pl = players_[card.controler_]; auto& pl = players_[card.controler_];
auto op = players_[1 - card.controler_]; auto& op = players_[1 - card.controler_];
auto plspec = card.get_spec(false); auto plspec = card.get_spec(false);
auto opspec = card.get_spec(true); auto opspec = card.get_spec(true);
auto plnewspec = cnew.get_spec(false); auto plnewspec = cnew.get_spec(false);
auto opnewspec = cnew.get_spec(true); auto opnewspec = cnew.get_spec(true);
auto getspec = [&](Player *p) { return p == pl ? plspec : opspec; }; auto getspec = [&](auto& p) { return p.get() == pl.get() ? plspec : opspec; };
auto getnewspec = [&](Player *p) { auto getnewspec = [&](auto& p) {
return p == pl ? plnewspec : opnewspec; return p.get() == pl.get() ? plnewspec : opnewspec;
}; };
bool card_visible = true; bool card_visible = true;
if ((card.position_ & POS_FACEDOWN) && (cnew.position_ & POS_FACEDOWN)) { if ((card.position_ & POS_FACEDOWN) && (cnew.position_ & POS_FACEDOWN)) {
card_visible = false; card_visible = false;
} }
auto getvisiblename = [&](Player *p) { auto getvisiblename = [&](auto& p) {
return card_visible ? card.name_ : "Face-down card"; return card_visible ? card.name_ : "Face-down card";
}; };
...@@ -3336,8 +3337,8 @@ private: ...@@ -3336,8 +3337,8 @@ private:
Card card = c_get_card(code); Card card = c_get_card(code);
card.set_location(location); card.set_location(location);
auto c = card.controler_; auto c = card.controler_;
auto cpl = players_[c]; auto& cpl = players_[c];
auto opl = players_[1 - c]; auto& opl = players_[1 - c];
cpl->notify(fmt::format("You set {} ({}) in {} position.", card.name_, cpl->notify(fmt::format("You set {} ({}) in {} position.", card.name_,
card.get_spec(c), card.get_position())); card.get_spec(c), card.get_position()));
opl->notify(fmt::format("{} sets {} in {} position.", cpl->nickname_, opl->notify(fmt::format("{} sets {} in {} position.", cpl->nickname_,
...@@ -3435,8 +3436,8 @@ private: ...@@ -3435,8 +3436,8 @@ private:
uint8_t prevpos = card.position_; uint8_t prevpos = card.position_;
card.position_ = read_u8(); card.position_ = read_u8();
auto pl = players_[card.controler_]; auto& pl = players_[card.controler_];
auto op = players_[1 - card.controler_]; auto& op = players_[1 - card.controler_];
auto plspec = card.get_spec(false); auto plspec = card.get_spec(false);
auto opspec = card.get_spec(true); auto opspec = card.get_spec(true);
auto prevpos_str = position_to_string(prevpos); auto prevpos_str = position_to_string(prevpos);
...@@ -3482,7 +3483,7 @@ private: ...@@ -3482,7 +3483,7 @@ private:
} }
for (PlayerId pl = 0; pl < 2; pl++) { for (PlayerId pl = 0; pl < 2; pl++) {
auto p = players_[pl]; auto& p = players_[pl];
if (pl == player) { if (pl == player) {
p->notify(fmt::format("You reveal {} cards from your deck:", size)); p->notify(fmt::format("You reveal {} cards from your deck:", size));
} else { } else {
...@@ -3514,7 +3515,7 @@ private: ...@@ -3514,7 +3515,7 @@ private:
} }
for (PlayerId pl = 0; pl < 2; pl++) { for (PlayerId pl = 0; pl < 2; pl++) {
auto p = players_[pl]; auto& p = players_[pl];
auto s = "card is"; auto s = "card is";
if (count > 1) { if (count > 1) {
s = "cards are"; s = "cards are";
...@@ -3553,7 +3554,7 @@ private: ...@@ -3553,7 +3554,7 @@ private:
Card card1 = get_card(c1, l1, s1); Card card1 = get_card(c1, l1, s1);
Card card2 = get_card(c2, l2, s2); Card card2 = get_card(c2, l2, s2);
for (PlayerId pl = 0; pl < 2; pl++) { for (PlayerId pl = 0; pl < 2; pl++) {
auto p = players_[pl]; auto& p = players_[pl];
auto spec1 = card1.get_spec(pl); auto spec1 = card1.get_spec(pl);
auto spec2 = card2.get_spec(pl); auto spec2 = card2.get_spec(pl);
auto c1name = card1.name_; auto c1name = card1.name_;
...@@ -3584,8 +3585,8 @@ private: ...@@ -3584,8 +3585,8 @@ private:
return; return;
} }
auto pl = players_[player]; auto& pl = players_[player];
auto op = players_[1 - player]; auto& op = players_[1 - player];
op->notify(fmt::format("{} shows you {} cards.", pl->nickname_, size)); op->notify(fmt::format("{} shows you {} cards.", pl->nickname_, size));
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
...@@ -3624,7 +3625,7 @@ private: ...@@ -3624,7 +3625,7 @@ private:
auto seq = read_u8(); auto seq = read_u8();
cards.push_back(get_card(c, loc, seq)); cards.push_back(get_card(c, loc, seq));
} }
auto pl = players_[player]; auto& pl = players_[player];
pl->notify( pl->notify(
"Sort " + std::to_string(size) + "Sort " + std::to_string(size) +
" cards by entering numbers separated by spaces (c = cancel):"); " cards by entering numbers separated by spaces (c = cancel):");
...@@ -3677,9 +3678,9 @@ private: ...@@ -3677,9 +3678,9 @@ private:
auto seq = read_u8(); auto seq = read_u8();
auto count = read_u16(); auto count = read_u16();
auto c = get_card(player, loc, seq); auto c = get_card(player, loc, seq);
auto pl = players_[player]; auto& pl = players_[player];
PlayerId op_id = 1 - player; PlayerId op_id = 1 - player;
auto op = players_[op_id]; auto& op = players_[op_id];
// TODO(3): counter type to string // TODO(3): counter type to string
pl->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(player))); pl->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(player)));
op->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(op_id))); op->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(op_id)));
...@@ -3694,9 +3695,9 @@ private: ...@@ -3694,9 +3695,9 @@ private:
auto seq = read_u8(); auto seq = read_u8();
auto count = read_u16(); auto count = read_u16();
auto c = get_card(player, loc, seq); auto c = get_card(player, loc, seq);
auto pl = players_[player]; auto& pl = players_[player];
PlayerId op_id = 1 - player; PlayerId op_id = 1 - player;
auto op = players_[op_id]; auto& op = players_[op_id];
pl->notify(fmt::format("{} counter(s) of type {} removed from {} ().", count, "UNK", c.name_, c.get_spec(player))); pl->notify(fmt::format("{} counter(s) of type {} removed from {} ().", count, "UNK", c.name_, c.get_spec(player)));
op->notify(fmt::format("{} counter(s) of type {} removed from {} ().", count, "UNK", c.name_, c.get_spec(op_id))); op->notify(fmt::format("{} counter(s) of type {} removed from {} ().", count, "UNK", c.name_, c.get_spec(op_id)));
} else if (msg_ == MSG_ATTACK_DISABLED) { } else if (msg_ == MSG_ATTACK_DISABLED) {
...@@ -3720,8 +3721,8 @@ private: ...@@ -3720,8 +3721,8 @@ private:
return; return;
} }
auto player = read_u8(); auto player = read_u8();
auto pl = players_[player]; auto& pl = players_[player];
auto op = players_[1 - player]; auto& op = players_[1 - player];
pl->notify("You shuffled your deck."); pl->notify("You shuffled your deck.");
op->notify(pl->nickname_ + " shuffled their deck."); op->notify(pl->nickname_ + " shuffled their deck.");
} else if (msg_ == MSG_SHUFFLE_EXTRA) { } else if (msg_ == MSG_SHUFFLE_EXTRA) {
...@@ -3734,8 +3735,8 @@ private: ...@@ -3734,8 +3735,8 @@ private:
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
read_u32(); read_u32();
} }
auto pl = players_[player]; auto& pl = players_[player];
auto op = players_[1 - player]; auto& op = players_[1 - player];
pl->notify(fmt::format("You shuffled your extra deck ({}).", count)); pl->notify(fmt::format("You shuffled your extra deck ({}).", count));
op->notify(fmt::format("{} shuffled their extra deck ({}).", pl->nickname_, count)); op->notify(fmt::format("{} shuffled their extra deck ({}).", pl->nickname_, count));
} else if (msg_ == MSG_SHUFFLE_HAND) { } else if (msg_ == MSG_SHUFFLE_HAND) {
...@@ -3747,8 +3748,8 @@ private: ...@@ -3747,8 +3748,8 @@ private:
auto player = read_u8(); auto player = read_u8();
dp_ = dl_; dp_ = dl_;
auto pl = players_[player]; auto& pl = players_[player];
auto op = players_[1 - player]; auto& op = players_[1 - player];
pl->notify("You shuffled your hand."); pl->notify("You shuffled your hand.");
op->notify(pl->nickname_ + " shuffled their hand."); op->notify(pl->nickname_ + " shuffled their hand.");
} else if (msg_ == MSG_SUMMONED) { } else if (msg_ == MSG_SUMMONED) {
...@@ -3762,7 +3763,7 @@ private: ...@@ -3762,7 +3763,7 @@ private:
Card card = c_get_card(code); Card card = c_get_card(code);
card.set_location(read_u32()); card.set_location(read_u32());
const auto &nickname = players_[card.controler_]->nickname_; const auto &nickname = players_[card.controler_]->nickname_;
for (auto pl : players_) { for (auto& pl : players_) {
pl->notify(nickname + " summoning " + card.name_ + " (" + pl->notify(nickname + " summoning " + card.name_ + " (" +
std::to_string(card.attack_) + "/" + std::to_string(card.attack_) + "/" +
std::to_string(card.defense_) + ") in " + std::to_string(card.defense_) + ") in " +
...@@ -3783,7 +3784,7 @@ private: ...@@ -3783,7 +3784,7 @@ private:
Card card = c_get_card(code); Card card = c_get_card(code);
card.set_location(location); card.set_location(location);
auto cpl = players_[card.controler_]; auto& cpl = players_[card.controler_];
for (PlayerId pl = 0; pl < 2; pl++) { for (PlayerId pl = 0; pl < 2; pl++) {
auto spec = card.get_spec(pl); auto spec = card.get_spec(pl);
players_[1 - pl]->notify(cpl->nickname_ + " flip summons " + spec + players_[1 - pl]->notify(cpl->nickname_ + " flip summons " + spec +
...@@ -3799,7 +3800,7 @@ private: ...@@ -3799,7 +3800,7 @@ private:
card.set_location(read_u32()); card.set_location(read_u32());
const auto &nickname = players_[card.controler_]->nickname_; const auto &nickname = players_[card.controler_]->nickname_;
for (PlayerId p = 0; p < 2; p++) { for (PlayerId p = 0; p < 2; p++) {
auto pl = players_[p]; auto& pl = players_[p];
auto pos = card.get_position(); auto pos = card.get_position();
auto atk = std::to_string(card.attack_); auto atk = std::to_string(card.attack_);
auto def = std::to_string(card.defense_); auto def = std::to_string(card.defense_);
...@@ -3868,7 +3869,7 @@ private: ...@@ -3868,7 +3869,7 @@ private:
if (!verbose_) { if (!verbose_) {
return; return;
} }
auto pl = players_[player]; auto& pl = players_[player];
pl->notify("You pay " + std::to_string(cost) + " LP. Your LP is now " + pl->notify("You pay " + std::to_string(cost) + " LP. Your LP is now " +
std::to_string(lp_[player]) + "."); std::to_string(lp_[player]) + ".");
players_[1 - player]->notify( players_[1 - player]->notify(
...@@ -3959,7 +3960,7 @@ private: ...@@ -3959,7 +3960,7 @@ private:
tcard = get_card(tc, tloc, tseq); tcard = get_card(tc, tloc, tseq);
} }
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto pl = players_[i]; auto& pl = players_[i];
std::string attacker_points; std::string attacker_points;
if (acard.type_ & TYPE_LINK) { if (acard.type_ & TYPE_LINK) {
attacker_points = std::to_string(aa); attacker_points = std::to_string(aa);
...@@ -3982,8 +3983,8 @@ private: ...@@ -3982,8 +3983,8 @@ private:
} else if (msg_ == MSG_WIN) { } else if (msg_ == MSG_WIN) {
auto player = read_u8(); auto player = read_u8();
auto reason = read_u8(); auto reason = read_u8();
auto winner = players_[player]; auto& winner = players_[player];
auto loser = players_[1 - player]; auto& loser = players_[1 - player];
_duel_end(player, reason); _duel_end(player, reason);
...@@ -4001,7 +4002,7 @@ private: ...@@ -4001,7 +4002,7 @@ private:
bool to_m2 = read_u8(); bool to_m2 = read_u8();
bool to_ep = read_u8(); bool to_ep = read_u8();
auto pl = players_[player]; auto& pl = players_[player];
if (verbose_) { if (verbose_) {
pl->notify("Battle menu:"); pl->notify("Battle menu:");
} }
...@@ -4096,7 +4097,7 @@ private: ...@@ -4096,7 +4097,7 @@ private:
std::vector<std::string> select_specs; std::vector<std::string> select_specs;
select_specs.reserve(select_size); select_specs.reserve(select_size);
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto& pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " + pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards:"); std::to_string(max) + " cards:");
for (int i = 0; i < select_size; ++i) { for (int i = 0; i < select_size; ++i) {
...@@ -4169,7 +4170,7 @@ private: ...@@ -4169,7 +4170,7 @@ private:
card.set_location(loc); card.set_location(loc);
cards.push_back(card); cards.push_back(card);
} }
auto pl = players_[player]; auto& pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " + pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards separated by spaces:"); std::to_string(max) + " cards separated by spaces:");
for (const auto &card : cards) { for (const auto &card : cards) {
...@@ -4255,7 +4256,7 @@ private: ...@@ -4255,7 +4256,7 @@ private:
cards.push_back(card); cards.push_back(card);
release_params.push_back(release_param); release_params.push_back(release_param);
} }
auto pl = players_[player]; auto& pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " + pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + std::to_string(max) +
" cards to tribute separated by spaces:"); " cards to tribute separated by spaces:");
...@@ -4348,7 +4349,7 @@ private: ...@@ -4348,7 +4349,7 @@ private:
must_select.push_back(card); must_select.push_back(card);
expected -= (param & 0xff); expected -= (param & 0xff);
} }
auto pl = players_[player]; auto& pl = players_[player];
pl->notify("Select cards with a total value of " + pl->notify("Select cards with a total value of " +
std::to_string(expected) + ", seperated by spaces."); std::to_string(expected) + ", seperated by spaces.");
for (const auto &card : must_select) { for (const auto &card : must_select) {
...@@ -4386,7 +4387,7 @@ private: ...@@ -4386,7 +4387,7 @@ private:
select.push_back(card); select.push_back(card);
select_params.push_back(param); select_params.push_back(param);
} }
auto pl = players_[player]; auto& pl = players_[player];
for (const auto &card : select) { for (const auto &card : select) {
auto spec = card.get_spec(player); auto spec = card.get_spec(player);
select_specs.push_back(spec); select_specs.push_back(spec);
...@@ -4472,8 +4473,8 @@ private: ...@@ -4472,8 +4473,8 @@ private:
return; return;
} }
auto pl = players_[player]; auto& pl = players_[player];
auto op = players_[1 - player]; auto& op = players_[1 - player];
chaining_player_ = player; chaining_player_ = player;
if (!op->seen_waiting_) { if (!op->seen_waiting_) {
if (verbose_) { if (verbose_) {
...@@ -4543,7 +4544,7 @@ private: ...@@ -4543,7 +4544,7 @@ private:
} }
legal_actions_.push_back(la); legal_actions_.push_back(la);
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto& pl = players_[player];
std::string s; std::string s;
if (code == 0) { if (code == 0) {
s = get_system_string(eff_idx); s = get_system_string(eff_idx);
...@@ -4596,7 +4597,7 @@ private: ...@@ -4596,7 +4597,7 @@ private:
if (verbose_) { if (verbose_) {
Card c = c_get_card(code); Card c = c_get_card(code);
auto pl = players_[player]; auto& pl = players_[player];
auto name = c.name_; auto name = c.name_;
std::string s; std::string s;
if (code_d == 0) { if (code_d == 0) {
...@@ -4695,7 +4696,7 @@ private: ...@@ -4695,7 +4696,7 @@ private:
int offset = 0; int offset = 0;
auto pl = players_[player]; auto& pl = players_[player];
if (verbose_) { if (verbose_) {
pl->notify("Select a card and action to perform."); pl->notify("Select a card and action to perform.");
} }
...@@ -4896,7 +4897,7 @@ private: ...@@ -4896,7 +4897,7 @@ private:
throw std::runtime_error("Select counter count " + throw std::runtime_error("Select counter count " +
std::to_string(count) + " not implemented"); std::to_string(count) + " not implemented");
} }
auto pl = players_[player]; auto& pl = players_[player];
if (verbose_) { if (verbose_) {
pl->notify(fmt::format("Type new {} for {} card(s), separated by spaces.", "UNKNOWN_COUNTER", count)); pl->notify(fmt::format("Type new {} for {} card(s), separated by spaces.", "UNKNOWN_COUNTER", count));
} }
...@@ -4943,7 +4944,7 @@ private: ...@@ -4943,7 +4944,7 @@ private:
legal_actions_.push_back(LegalAction::number(number)); legal_actions_.push_back(LegalAction::number(number));
} }
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto& pl = players_[player];
std::string str = "Select a number, one of:"; std::string str = "Select a number, one of:";
pl->notify(str); pl->notify(str);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
...@@ -4974,7 +4975,7 @@ private: ...@@ -4974,7 +4975,7 @@ private:
} }
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto& pl = players_[player];
pl->notify("Select " + std::to_string(count) + pl->notify("Select " + std::to_string(count) +
" attributes separated by spaces:"); " attributes separated by spaces:");
for (int i = 0; i < attrs.size(); i++) { for (int i = 0; i < attrs.size(); i++) {
...@@ -5002,7 +5003,7 @@ private: ...@@ -5002,7 +5003,7 @@ private:
CardId cid = c_get_card_id(code); CardId cid = c_get_card_id(code);
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto& pl = players_[player];
auto card = c_get_card(code); auto card = c_get_card(code);
pl->notify("Select position for " + card.name_ + ":"); pl->notify("Select position for " + card.name_ + ":");
} }
...@@ -5016,7 +5017,7 @@ private: ...@@ -5016,7 +5017,7 @@ private:
legal_actions_.push_back(la); legal_actions_.push_back(la);
int cmd_idx = legal_actions_.size(); int cmd_idx = legal_actions_.size();
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto& pl = players_[player];
pl->notify(fmt::format("{}: {}", cmd_idx, position_to_string(pos))); pl->notify(fmt::format("{}: {}", cmd_idx, position_to_string(pos)));
} }
} }
...@@ -5040,7 +5041,7 @@ private: ...@@ -5040,7 +5041,7 @@ private:
void _damage(uint8_t player, uint32_t amount) { void _damage(uint8_t player, uint32_t amount) {
lp_[player] -= amount; lp_[player] -= amount;
if (verbose_) { if (verbose_) {
auto lp = players_[player]; auto& lp = players_[player];
lp->notify(fmt::format("Your lp decreased by {}, now {}", amount, lp_[player])); lp->notify(fmt::format("Your lp decreased by {}, now {}", amount, lp_[player]));
players_[1 - player]->notify(fmt::format("{}'s lp decreased by {}, now {}", players_[1 - player]->notify(fmt::format("{}'s lp decreased by {}, now {}",
lp->nickname_, amount, lp_[player])); lp->nickname_, amount, lp_[player]));
...@@ -5050,7 +5051,7 @@ private: ...@@ -5050,7 +5051,7 @@ private:
void _recover(uint8_t player, uint32_t amount) { void _recover(uint8_t player, uint32_t amount) {
lp_[player] += amount; lp_[player] += amount;
if (verbose_) { if (verbose_) {
auto lp = players_[player]; auto& lp = players_[player];
lp->notify(fmt::format("Your lp increased by {}, now {}", amount, lp_[player])); lp->notify(fmt::format("Your lp increased by {}, now {}", amount, lp_[player]));
players_[1 - player]->notify(fmt::format("{}'s lp increased by {}, now {}", players_[1 - player]->notify(fmt::format("{}'s lp increased by {}, now {}",
lp->nickname_, amount, lp_[player])); lp->nickname_, amount, lp_[player]));
...@@ -5074,34 +5075,113 @@ private: ...@@ -5074,34 +5075,113 @@ private:
class YGOProEnv : public Env<YGOProEnvSpec> { class YGOProEnv : public Env<YGOProEnvSpec> {
protected: protected:
const int max_episode_steps_; const int max_episode_steps_;
const int timeout_;
int elapsed_step_; int elapsed_step_;
std::uniform_int_distribution<uint64_t> dist_int_; std::uniform_int_distribution<uint64_t> dist_int_;
YGOProEnvImpl env_impl_;
// The pool can't be in vector, so we create multiple pools manually
BS::thread_pool pool0_;
BS::thread_pool pool1_;
BS::thread_pool pool2_;
BS::thread_pool pool3_;
BS::thread_pool pool4_;
const int max_timeout_{5};
// YGOProEnvImpl env_impl0_;
// YGOProEnvImpl env_impl1_;
// YGOProEnvImpl env_impl2_;
// YGOProEnvImpl env_impl3_;
// YGOProEnvImpl env_impl4_;
std::vector<YGOProEnvImpl> env_impls_;
bool done_{true};
public: public:
YGOProEnv(const Spec &spec, int env_id) YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id), : Env<YGOProEnvSpec>(spec, env_id),
max_episode_steps_(spec.config["max_episode_steps"_]), max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff), elapsed_step_(max_episode_steps_ + 1),
env_impl_(YGOProEnvImpl(spec, dist_int_(gen_))) { timeout_(spec.config["timeout"_]),
pool0_(1), pool1_(1), pool2_(1), pool3_(1), pool4_(1),
dist_int_(0, 0xffffffff) {
env_impls_.reserve(max_timeout_);
env_impls_.emplace_back(spec, dist_int_(gen_));
}
bool IsDone() override { return done_; }
BS::thread_pool& get_pool(int idx) {
switch (idx) {
case 0: return pool0_;
case 1: return pool1_;
case 2: return pool2_;
case 3: return pool3_;
case 4: return pool4_;
default: throw std::runtime_error("Invalid pool index");
}
} }
bool IsDone() override { return env_impl_.done(); } void handle_timeout() {
env_impls_.emplace_back(spec_, dist_int_(gen_));
if (env_impls_.capacity() > max_timeout_) {
throw std::runtime_error("Too many timeouts");
}
done_ = true;
State state = Allocate();
state["reward"_] = 1.0;
state["info:to_play"_] = 1;
state["info:is_selfplay"_] = 1;
state["info:win_reason"_] = 1;
state["info:num_options"_] = 1;
state["obs:global_"_][22] = uint8_t(1);
}
void Reset() override { void Reset() override {
env_impl_.reset(); int idx = env_impls_.size() - 1;
elapsed_step_ = 0; auto& pool = get_pool(idx);
auto fut = pool.submit_task([this, idx]() {
env_impls_[idx].reset();
});
if (fut.wait_for(std::chrono::seconds(timeout_)) != std::future_status::ready) {
throw std::runtime_error("Reset timeout");
}
auto &env_impl = env_impls_[idx];
elapsed_step_ = 0;
done_ = false;
State state = Allocate(); State state = Allocate();
env_impl_.WriteState(state, 0.0); env_impl.WriteState(state);
} }
void Step(const Action &action) override { void Step(const Action &action) override {
auto [reward, win_reason] = env_impl_.step(action["action"_]); int idx = env_impls_.size() - 1;
auto& pool = get_pool(idx);
int action_idx = action["action"_];
pool.detach_task([this, action_idx, idx]() {
// Test timeout: random sleep with probability 0.01
// if (dist_int_(gen_) % 10000 == 0) {
// fmt::println("Env {} sleep {}", env_id_, env_impls_.capacity());
// std::this_thread::sleep_for(std::chrono::seconds(5));
// fmt::println("Env {} after {}", env_id_, env_impls_.capacity());
// auto& env_impl = env_impls_[idx];
// env_impl.step(action_idx);
// std::this_thread::sleep_for(std::chrono::seconds(1));
// return;
// }
env_impls_[idx].step(action_idx);
});
if (!pool.wait_for(std::chrono::seconds(timeout_))) {
handle_timeout();
fmt::println("Env {} timeout, new env created", env_id_);
} else {
auto& env_impl = env_impls_[idx];
done_ = env_impl.ret_reward_ != 0;
State state = Allocate(); State state = Allocate();
env_impl_.WriteState(state, reward, win_reason); env_impl.WriteState(state);
}
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment