Commit c90562de authored by sbl1996@126.com

Env fault tolerant training via timeout

parent b7e43382
......@@ -153,7 +153,7 @@ if __name__ == "__main__":
play_mode='self',
async_reset=False,
verbose=args.verbose,
record=args.record,
record=args.record,
)
envs1 = ygoenv.make(
task_id=env_id1,
......@@ -311,6 +311,7 @@ if __name__ == "__main__":
for idx, d in enumerate(dones1):
if not d or (args.accurate and collected[idx]):
continue
# c1 = collected[idx]
collected[idx] = True
win_reason = infos1['win_reason'][idx]
pl = 1 if main[idx] else -1
......@@ -323,7 +324,8 @@ if __name__ == "__main__":
win_players.append(win_player)
win_agent = 1 if main_reward > 0 else 2
win_agents.append(win_agent)
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
# if not c1:
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
episode_lengths.append(episode_length)
episode_rewards.append(main_reward)
win_rates.append(win)
......
......@@ -49,6 +49,8 @@ class Args:
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
timeout: int = 600
"""the timeout of the environment step"""
debug: bool = False
"""whether to run the script in debug mode"""
......@@ -208,6 +210,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
async_reset=False,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
timeout=args.timeout,
)
envs.num_envs = num_envs
return envs
......
......@@ -210,17 +210,17 @@ if __name__ == "__main__":
for idx, d in enumerate(dones):
if not d:
continue
for i in range(2):
deck_time = infos['step_time'][idx][i]
deck_name = deck_names[infos['deck'][idx][i]]
time_count = deck_time_count[deck_name]
avg_time = deck_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
deck_times[deck_name] = avg_time
deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % 100 == 0:
print(f"Deck {deck_name}: {avg_time:.4f}")
# for i in range(2):
# deck_time = infos['step_time'][idx][i]
# deck_name = deck_names[infos['deck'][idx][i]]
# time_count = deck_time_count[deck_name]
# avg_time = deck_times[deck_name]
# avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
# deck_times[deck_name] = avg_time
# deck_time_count[deck_name] += 1
# if deck_time_count[deck_name] % 100 == 0:
# print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
......
......@@ -3,6 +3,7 @@
// clang-format off
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <ctime>
......@@ -1525,7 +1526,8 @@ public:
"play_mode"_.Bind(std::string("bot")),
"verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(true), "greedy_reward"_.Bind(true));
"record"_.Bind(false), "async_reset"_.Bind(false),
"greedy_reward"_.Bind(true), "timeout"_.Bind(600));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
......@@ -1616,12 +1618,12 @@ protected:
const int player_;
PlayMode play_mode_;
const bool verbose_ = false;
bool verbose_ = false;
PlayerId ai_player_;
intptr_t pduel_ = 0;
Player *players_[2]; // abstract class must be pointer
std::unique_ptr<Player> players_[2]; // abstract class must be pointer
std::uniform_int_distribution<uint64_t> dist_int_;
bool done_{true};
......@@ -1708,19 +1710,25 @@ protected:
// async reset
const bool async_reset_;
int n_lives_ = 0;
std::future<MDuel> duel_fut_;
BS::thread_pool pool_;
// std::future<MDuel> duel_fut_;
// BS::thread_pool pool_;
std::mt19937 duel_gen_;
public:
YGOProEnvImpl(const EnvSpec<YGOProEnvFns> &spec, uint32_t env_seed)
// step return
float ret_reward_ = 0;
int ret_win_reason_ = 0;
YGOProEnvImpl();
YGOProEnvImpl(const EnvSpec<YGOProEnvFns> &spec, uint64_t env_seed)
: spec_(spec), dist_int_(0, 0xffffffff),
deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]),
player_(spec.config["player"_]), players_{nullptr, nullptr},
play_modes_(parse_play_modes(spec.config["play_mode"_])),
verbose_(spec.config["verbose"_]), record_(spec.config["record"_]),
n_history_actions_(spec.config["n_history_actions"_]), pool_(BS::thread_pool(1)),
n_history_actions_(spec.config["n_history_actions"_]),
async_reset_(spec.config["async_reset"_]), greedy_reward_(spec.config["greedy_reward"_]) {
if (record_) {
if (!verbose_) {
......@@ -1751,14 +1759,6 @@ public:
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
}
// Destructor: deletes each non-null entry of players_.
// NOTE(review): this manual delete only fits the raw `Player *players_[2]`
// declaration; the same listing also shows a std::unique_ptr<Player> variant
// of players_, which would not need (or compile with) this loop — confirm
// which declaration is live.
~YGOProEnvImpl() {
for (int i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
delete players_[i];
}
}
}
int max_options() const { return spec_.config["max_options"_]; }
int max_cards() const { return spec_.config["max_cards"_]; }
......@@ -1858,21 +1858,17 @@ public:
extra_deck1_ = mduel.extra_deck1;
for (PlayerId i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
delete players_[i];
}
std::string nickname = i == 0 ? "Alice" : "Bob";
if (i == ai_player_) {
nickname = "Agent";
}
nickname_[i] = nickname;
if ((play_mode_ == kHuman) && (i != ai_player_)) {
players_[i] = new HumanPlayer(nickname_[i], init_lp_, i, verbose_);
players_[i] = std::make_unique<HumanPlayer>(nickname, init_lp_, i, verbose_);
} else if (play_mode_ == kRandomBot) {
players_[i] = new RandomAI(max_options(), dist_int_(gen_), nickname_[i],
init_lp_, i, verbose_);
players_[i] = std::make_unique<RandomAI>(max_options(), dist_int_(gen_), nickname, init_lp_, i, verbose_);
} else {
players_[i] = new GreedyAI(nickname_[i], init_lp_, i, verbose_);
players_[i] = std::make_unique<GreedyAI>(nickname, init_lp_, i, verbose_);
}
lp_[i] = players_[i]->init_lp_;
}
......@@ -1955,6 +1951,9 @@ public:
next();
ret_reward_ = 0;
ret_win_reason_ = 0;
// if (async_reset_) {
// duel_fut_ = pool_.submit_task([
// this, old_duel, duel_seed=dist_int_(gen_)] {
......@@ -2210,7 +2209,7 @@ public:
}
}
std::tuple<float, int> step(int idx) {
void step(int idx) {
callback_(idx);
update_history_actions(to_play_, legal_actions_[idx]);
......@@ -2301,7 +2300,8 @@ public:
// if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000);
// }
return {reward, reason};
ret_reward_ = reward;
ret_win_reason_ = reason;
}
using YGOProEnvSpec = EnvSpec<YGOProEnvFns>;
......@@ -2309,16 +2309,17 @@ public:
Dict<typename YGOProEnvSpec::StateKeys,
typename SpecToTArray<typename YGOProEnvSpec::StateSpec::Values>::Type>;
void WriteState(
State &state, float reward, int win_reason = 0, double step_time = 0.0) {
void WriteState(State &state) {
float reward = ret_reward_;
int win_reason = ret_win_reason_;
int n_options = legal_actions_.size();
state["reward"_] = reward;
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason;
if (reward != 0.0) {
state["info:step_time"_][0] = step_time;
state["info:step_time"_][1] = step_time;
state["info:step_time"_][0] = 0;
state["info:step_time"_][1] = 0;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
}
......@@ -3167,7 +3168,7 @@ private:
if (!verbose_) {
return;
}
auto player = players_[tp_];
auto& player = players_[tp_];
player->notify("Your turn.");
players_[1 - tp_]->notify(fmt::format("{}'s turn.", player->nickname_));
} else if (msg_ == MSG_NEW_PHASE) {
......@@ -3192,23 +3193,23 @@ private:
card.set_location(location);
Card cnew = c_get_card(code);
cnew.set_location(newloc);
auto pl = players_[card.controler_];
auto op = players_[1 - card.controler_];
auto& pl = players_[card.controler_];
auto& op = players_[1 - card.controler_];
auto plspec = card.get_spec(false);
auto opspec = card.get_spec(true);
auto plnewspec = cnew.get_spec(false);
auto opnewspec = cnew.get_spec(true);
auto getspec = [&](Player *p) { return p == pl ? plspec : opspec; };
auto getnewspec = [&](Player *p) {
return p == pl ? plnewspec : opnewspec;
auto getspec = [&](auto& p) { return p.get() == pl.get() ? plspec : opspec; };
auto getnewspec = [&](auto& p) {
return p.get() == pl.get() ? plnewspec : opnewspec;
};
bool card_visible = true;
if ((card.position_ & POS_FACEDOWN) && (cnew.position_ & POS_FACEDOWN)) {
card_visible = false;
}
auto getvisiblename = [&](Player *p) {
auto getvisiblename = [&](auto& p) {
return card_visible ? card.name_ : "Face-down card";
};
......@@ -3336,8 +3337,8 @@ private:
Card card = c_get_card(code);
card.set_location(location);
auto c = card.controler_;
auto cpl = players_[c];
auto opl = players_[1 - c];
auto& cpl = players_[c];
auto& opl = players_[1 - c];
cpl->notify(fmt::format("You set {} ({}) in {} position.", card.name_,
card.get_spec(c), card.get_position()));
opl->notify(fmt::format("{} sets {} in {} position.", cpl->nickname_,
......@@ -3435,8 +3436,8 @@ private:
uint8_t prevpos = card.position_;
card.position_ = read_u8();
auto pl = players_[card.controler_];
auto op = players_[1 - card.controler_];
auto& pl = players_[card.controler_];
auto& op = players_[1 - card.controler_];
auto plspec = card.get_spec(false);
auto opspec = card.get_spec(true);
auto prevpos_str = position_to_string(prevpos);
......@@ -3482,7 +3483,7 @@ private:
}
for (PlayerId pl = 0; pl < 2; pl++) {
auto p = players_[pl];
auto& p = players_[pl];
if (pl == player) {
p->notify(fmt::format("You reveal {} cards from your deck:", size));
} else {
......@@ -3514,7 +3515,7 @@ private:
}
for (PlayerId pl = 0; pl < 2; pl++) {
auto p = players_[pl];
auto& p = players_[pl];
auto s = "card is";
if (count > 1) {
s = "cards are";
......@@ -3553,7 +3554,7 @@ private:
Card card1 = get_card(c1, l1, s1);
Card card2 = get_card(c2, l2, s2);
for (PlayerId pl = 0; pl < 2; pl++) {
auto p = players_[pl];
auto& p = players_[pl];
auto spec1 = card1.get_spec(pl);
auto spec2 = card2.get_spec(pl);
auto c1name = card1.name_;
......@@ -3584,8 +3585,8 @@ private:
return;
}
auto pl = players_[player];
auto op = players_[1 - player];
auto& pl = players_[player];
auto& op = players_[1 - player];
op->notify(fmt::format("{} shows you {} cards.", pl->nickname_, size));
for (int i = 0; i < size; ++i) {
......@@ -3624,7 +3625,7 @@ private:
auto seq = read_u8();
cards.push_back(get_card(c, loc, seq));
}
auto pl = players_[player];
auto& pl = players_[player];
pl->notify(
"Sort " + std::to_string(size) +
" cards by entering numbers separated by spaces (c = cancel):");
......@@ -3677,9 +3678,9 @@ private:
auto seq = read_u8();
auto count = read_u16();
auto c = get_card(player, loc, seq);
auto pl = players_[player];
auto& pl = players_[player];
PlayerId op_id = 1 - player;
auto op = players_[op_id];
auto& op = players_[op_id];
// TODO(3): counter type to string
pl->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(player)));
op->notify(fmt::format("{} counter(s) of type {} placed on {} ().", count, "UNK", c.name_, c.get_spec(op_id)));
......@@ -3694,9 +3695,9 @@ private:
auto seq = read_u8();
auto count = read_u16();
auto c = get_card(player, loc, seq);
auto pl = players_[player];
auto& pl = players_[player];
PlayerId op_id = 1 - player;
auto op = players_[op_id];
auto& op = players_[op_id];
pl->notify(fmt::format("{} counter(s) of type {} removed from {} ().", count, "UNK", c.name_, c.get_spec(player)));
op->notify(fmt::format("{} counter(s) of type {} removed from {} ().", count, "UNK", c.name_, c.get_spec(op_id)));
} else if (msg_ == MSG_ATTACK_DISABLED) {
......@@ -3720,8 +3721,8 @@ private:
return;
}
auto player = read_u8();
auto pl = players_[player];
auto op = players_[1 - player];
auto& pl = players_[player];
auto& op = players_[1 - player];
pl->notify("You shuffled your deck.");
op->notify(pl->nickname_ + " shuffled their deck.");
} else if (msg_ == MSG_SHUFFLE_EXTRA) {
......@@ -3734,8 +3735,8 @@ private:
for (int i = 0; i < count; ++i) {
read_u32();
}
auto pl = players_[player];
auto op = players_[1 - player];
auto& pl = players_[player];
auto& op = players_[1 - player];
pl->notify(fmt::format("You shuffled your extra deck ({}).", count));
op->notify(fmt::format("{} shuffled their extra deck ({}).", pl->nickname_, count));
} else if (msg_ == MSG_SHUFFLE_HAND) {
......@@ -3747,8 +3748,8 @@ private:
auto player = read_u8();
dp_ = dl_;
auto pl = players_[player];
auto op = players_[1 - player];
auto& pl = players_[player];
auto& op = players_[1 - player];
pl->notify("You shuffled your hand.");
op->notify(pl->nickname_ + " shuffled their hand.");
} else if (msg_ == MSG_SUMMONED) {
......@@ -3762,7 +3763,7 @@ private:
Card card = c_get_card(code);
card.set_location(read_u32());
const auto &nickname = players_[card.controler_]->nickname_;
for (auto pl : players_) {
for (auto& pl : players_) {
pl->notify(nickname + " summoning " + card.name_ + " (" +
std::to_string(card.attack_) + "/" +
std::to_string(card.defense_) + ") in " +
......@@ -3783,7 +3784,7 @@ private:
Card card = c_get_card(code);
card.set_location(location);
auto cpl = players_[card.controler_];
auto& cpl = players_[card.controler_];
for (PlayerId pl = 0; pl < 2; pl++) {
auto spec = card.get_spec(pl);
players_[1 - pl]->notify(cpl->nickname_ + " flip summons " + spec +
......@@ -3799,7 +3800,7 @@ private:
card.set_location(read_u32());
const auto &nickname = players_[card.controler_]->nickname_;
for (PlayerId p = 0; p < 2; p++) {
auto pl = players_[p];
auto& pl = players_[p];
auto pos = card.get_position();
auto atk = std::to_string(card.attack_);
auto def = std::to_string(card.defense_);
......@@ -3868,7 +3869,7 @@ private:
if (!verbose_) {
return;
}
auto pl = players_[player];
auto& pl = players_[player];
pl->notify("You pay " + std::to_string(cost) + " LP. Your LP is now " +
std::to_string(lp_[player]) + ".");
players_[1 - player]->notify(
......@@ -3959,7 +3960,7 @@ private:
tcard = get_card(tc, tloc, tseq);
}
for (int i = 0; i < 2; i++) {
auto pl = players_[i];
auto& pl = players_[i];
std::string attacker_points;
if (acard.type_ & TYPE_LINK) {
attacker_points = std::to_string(aa);
......@@ -3982,8 +3983,8 @@ private:
} else if (msg_ == MSG_WIN) {
auto player = read_u8();
auto reason = read_u8();
auto winner = players_[player];
auto loser = players_[1 - player];
auto& winner = players_[player];
auto& loser = players_[1 - player];
_duel_end(player, reason);
......@@ -4001,7 +4002,7 @@ private:
bool to_m2 = read_u8();
bool to_ep = read_u8();
auto pl = players_[player];
auto& pl = players_[player];
if (verbose_) {
pl->notify("Battle menu:");
}
......@@ -4096,7 +4097,7 @@ private:
std::vector<std::string> select_specs;
select_specs.reserve(select_size);
if (verbose_) {
auto pl = players_[player];
auto& pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards:");
for (int i = 0; i < select_size; ++i) {
......@@ -4169,7 +4170,7 @@ private:
card.set_location(loc);
cards.push_back(card);
}
auto pl = players_[player];
auto& pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) + " cards separated by spaces:");
for (const auto &card : cards) {
......@@ -4255,7 +4256,7 @@ private:
cards.push_back(card);
release_params.push_back(release_param);
}
auto pl = players_[player];
auto& pl = players_[player];
pl->notify("Select " + std::to_string(min) + " to " +
std::to_string(max) +
" cards to tribute separated by spaces:");
......@@ -4348,7 +4349,7 @@ private:
must_select.push_back(card);
expected -= (param & 0xff);
}
auto pl = players_[player];
auto& pl = players_[player];
pl->notify("Select cards with a total value of " +
std::to_string(expected) + ", seperated by spaces.");
for (const auto &card : must_select) {
......@@ -4386,7 +4387,7 @@ private:
select.push_back(card);
select_params.push_back(param);
}
auto pl = players_[player];
auto& pl = players_[player];
for (const auto &card : select) {
auto spec = card.get_spec(player);
select_specs.push_back(spec);
......@@ -4472,8 +4473,8 @@ private:
return;
}
auto pl = players_[player];
auto op = players_[1 - player];
auto& pl = players_[player];
auto& op = players_[1 - player];
chaining_player_ = player;
if (!op->seen_waiting_) {
if (verbose_) {
......@@ -4543,7 +4544,7 @@ private:
}
legal_actions_.push_back(la);
if (verbose_) {
auto pl = players_[player];
auto& pl = players_[player];
std::string s;
if (code == 0) {
s = get_system_string(eff_idx);
......@@ -4596,7 +4597,7 @@ private:
if (verbose_) {
Card c = c_get_card(code);
auto pl = players_[player];
auto& pl = players_[player];
auto name = c.name_;
std::string s;
if (code_d == 0) {
......@@ -4695,7 +4696,7 @@ private:
int offset = 0;
auto pl = players_[player];
auto& pl = players_[player];
if (verbose_) {
pl->notify("Select a card and action to perform.");
}
......@@ -4896,7 +4897,7 @@ private:
throw std::runtime_error("Select counter count " +
std::to_string(count) + " not implemented");
}
auto pl = players_[player];
auto& pl = players_[player];
if (verbose_) {
pl->notify(fmt::format("Type new {} for {} card(s), separated by spaces.", "UNKNOWN_COUNTER", count));
}
......@@ -4943,7 +4944,7 @@ private:
legal_actions_.push_back(LegalAction::number(number));
}
if (verbose_) {
auto pl = players_[player];
auto& pl = players_[player];
std::string str = "Select a number, one of:";
pl->notify(str);
for (int i = 0; i < count; ++i) {
......@@ -4974,7 +4975,7 @@ private:
}
if (verbose_) {
auto pl = players_[player];
auto& pl = players_[player];
pl->notify("Select " + std::to_string(count) +
" attributes separated by spaces:");
for (int i = 0; i < attrs.size(); i++) {
......@@ -5002,7 +5003,7 @@ private:
CardId cid = c_get_card_id(code);
if (verbose_) {
auto pl = players_[player];
auto& pl = players_[player];
auto card = c_get_card(code);
pl->notify("Select position for " + card.name_ + ":");
}
......@@ -5016,7 +5017,7 @@ private:
legal_actions_.push_back(la);
int cmd_idx = legal_actions_.size();
if (verbose_) {
auto pl = players_[player];
auto& pl = players_[player];
pl->notify(fmt::format("{}: {}", cmd_idx, position_to_string(pos)));
}
}
......@@ -5040,7 +5041,7 @@ private:
void _damage(uint8_t player, uint32_t amount) {
lp_[player] -= amount;
if (verbose_) {
auto lp = players_[player];
auto& lp = players_[player];
lp->notify(fmt::format("Your lp decreased by {}, now {}", amount, lp_[player]));
players_[1 - player]->notify(fmt::format("{}'s lp decreased by {}, now {}",
lp->nickname_, amount, lp_[player]));
......@@ -5050,7 +5051,7 @@ private:
void _recover(uint8_t player, uint32_t amount) {
lp_[player] += amount;
if (verbose_) {
auto lp = players_[player];
auto& lp = players_[player];
lp->notify(fmt::format("Your lp increased by {}, now {}", amount, lp_[player]));
players_[1 - player]->notify(fmt::format("{}'s lp increased by {}, now {}",
lp->nickname_, amount, lp_[player]));
......@@ -5074,34 +5075,113 @@ private:
class YGOProEnv : public Env<YGOProEnvSpec> {
protected:
const int max_episode_steps_;
const int timeout_;
int elapsed_step_;
std::uniform_int_distribution<uint64_t> dist_int_;
YGOProEnvImpl env_impl_;
// The pool can't be in vector, so we create multiple pools manually
BS::thread_pool pool0_;
BS::thread_pool pool1_;
BS::thread_pool pool2_;
BS::thread_pool pool3_;
BS::thread_pool pool4_;
const int max_timeout_{5};
// YGOProEnvImpl env_impl0_;
// YGOProEnvImpl env_impl1_;
// YGOProEnvImpl env_impl2_;
// YGOProEnvImpl env_impl3_;
// YGOProEnvImpl env_impl4_;
std::vector<YGOProEnvImpl> env_impls_;
bool done_{true};
public:
// Constructs the environment wrapper from `spec` and `env_id`.
//
// Reads "max_episode_steps" and "timeout" from spec.config, starts
// elapsed_step_ at max_episode_steps_ + 1, constructs pool0_..pool4_ with
// argument 1 (presumably one worker thread each — confirm against the
// BS::thread_pool constructor), and bounds dist_int_ to [0, 0xffffffff].
// The body reserves max_timeout_ slots in env_impls_ and creates the first
// implementation with a seed drawn from dist_int_(gen_); handle_timeout()
// refuses to grow past max_timeout_ implementations.
//
// NOTE(review): this listing shows two initializer-list variants back to back
// (one initializing env_impl_, one initializing timeout_ and the pools); it
// looks like pre- and post-change lines of a diff — confirm which set is live.
YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id),
max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff),
env_impl_(YGOProEnvImpl(spec, dist_int_(gen_))) {
elapsed_step_(max_episode_steps_ + 1),
timeout_(spec.config["timeout"_]),
pool0_(1), pool1_(1), pool2_(1), pool3_(1), pool4_(1),
dist_int_(0, 0xffffffff) {
env_impls_.reserve(max_timeout_);
env_impls_.emplace_back(spec, dist_int_(gen_));
}
bool IsDone() override { return done_; }
// Returns the thread pool paired with environment slot `idx` (0 through 4).
// Any other index throws std::runtime_error("Invalid pool index").
BS::thread_pool& get_pool(int idx) {
  // The pools are separate members, so look them up through a pointer table.
  BS::thread_pool* const slots[] = {&pool0_, &pool1_, &pool2_, &pool3_, &pool4_};
  const int n_slots = static_cast<int>(sizeof(slots) / sizeof(slots[0]));
  if (idx < 0 || idx >= n_slots) {
    throw std::runtime_error("Invalid pool index");
  }
  return *slots[idx];
}
bool IsDone() override { return env_impl_.done(); }
// Recovers from a step that did not finish within the timeout.
//
// Appends a fresh YGOProEnvImpl (seeded from dist_int_(gen_)) to env_impls_,
// marks the episode as done, and fills a newly allocated state with fixed
// placeholder values so the caller still receives a terminal transition.
// The timed-out implementation is left in the vector: the detached task
// submitted by Step() may still be executing env_impls_[idx].step().
//
// Throws std::runtime_error("Too many timeouts") once all max_timeout_ slots
// are in use.
void handle_timeout() {
  // Check the limit BEFORE growing the vector. Growing past the capacity
  // reserved in the constructor would reallocate env_impls_ and move the
  // implementation a still-running step task is working on.
  if (static_cast<int>(env_impls_.size()) >= max_timeout_) {
    throw std::runtime_error("Too many timeouts");
  }
  env_impls_.emplace_back(spec_, dist_int_(gen_));
  done_ = true;

  // Placeholder terminal state; no real observation is available here.
  State state = Allocate();
  state["reward"_] = 1.0;
  state["info:to_play"_] = 1;
  state["info:is_selfplay"_] = 1;
  state["info:win_reason"_] = 1;
  state["info:num_options"_] = 1;
  // NOTE(review): the meaning of global feature index 22 is defined by the
  // observation layout elsewhere — confirm it is the intended flag.
  state["obs:global_"_][22] = uint8_t(1);
}
// Resets the most recently created environment implementation and writes its
// initial state.
//
// The reset runs as a task on the pool returned by get_pool(idx), where idx
// is the last index of env_impls_. If the task's future is not ready within
// timeout_ seconds, std::runtime_error("Reset timeout") is thrown. Otherwise
// elapsed_step_ is zeroed, done_ is cleared, and the implementation writes
// into a freshly allocated state.
//
// NOTE(review): the statements using `env_impl_` (singular) sit next to
// equivalents using `env_impls_[idx]`; this listing looks like it shows pre-
// and post-change lines of a diff together — confirm which set is live.
void Reset() override {
env_impl_.reset();
elapsed_step_ = 0;
int idx = env_impls_.size() - 1;
auto& pool = get_pool(idx);
// The lambda captures `this` and idx by value and indexes env_impls_ when it runs.
auto fut = pool.submit_task([this, idx]() {
env_impls_[idx].reset();
});
if (fut.wait_for(std::chrono::seconds(timeout_)) != std::future_status::ready) {
throw std::runtime_error("Reset timeout");
}
auto &env_impl = env_impls_[idx];
elapsed_step_ = 0;
done_ = false;
State state = Allocate();
env_impl_.WriteState(state, 0.0);
env_impl.WriteState(state);
}
// Applies action["action"_] to the most recently created environment
// implementation, bounded by a timeout.
//
// The step is submitted as a detached task on the pool returned by
// get_pool(idx). If pool.wait_for() reports that the work did not finish
// within timeout_ seconds, handle_timeout() is called and a message is
// printed. Otherwise done_ is set from whether ret_reward_ is non-zero and
// the implementation writes into a freshly allocated state.
//
// NOTE(review): the first three statements use `env_impl_` and a
// tuple-returning step(), while the rest use `env_impls_[idx]` and read
// ret_reward_ afterwards; this listing looks like it shows pre- and
// post-change lines of a diff together — confirm which set is live.
void Step(const Action &action) override {
auto [reward, win_reason] = env_impl_.step(action["action"_]);
State state = Allocate();
env_impl_.WriteState(state, reward, win_reason);
int idx = env_impls_.size() - 1;
auto& pool = get_pool(idx);
int action_idx = action["action"_];
// The task captures `this`, so it can still touch env_impls_[idx] after a
// timeout has been reported below.
pool.detach_task([this, action_idx, idx]() {
// Test timeout: random sleep with probability 0.01
// if (dist_int_(gen_) % 10000 == 0) {
// fmt::println("Env {} sleep {}", env_id_, env_impls_.capacity());
// std::this_thread::sleep_for(std::chrono::seconds(5));
// fmt::println("Env {} after {}", env_id_, env_impls_.capacity());
// auto& env_impl = env_impls_[idx];
// env_impl.step(action_idx);
// std::this_thread::sleep_for(std::chrono::seconds(1));
// return;
// }
env_impls_[idx].step(action_idx);
});
// A false result from wait_for() is treated as "step did not finish in time".
if (!pool.wait_for(std::chrono::seconds(timeout_))) {
handle_timeout();
fmt::println("Env {} timeout, new env created", env_id_);
} else {
auto& env_impl = env_impls_[idx];
// A non-zero reward is what marks the episode as finished here.
done_ = env_impl.ret_reward_ != 0;
State state = Allocate();
env_impl.WriteState(state);
}
}
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment