Commit b7e43382 authored by sbl1996@126.com's avatar sbl1996@126.com

Seperate implementation of YGOProEnv

parent 112f1bc6
...@@ -1062,7 +1062,7 @@ public: ...@@ -1062,7 +1062,7 @@ public:
}; };
class Card { class Card {
friend class YGOProEnv; friend class YGOProEnvImpl;
protected: protected:
CardCode code_ = 0; CardCode code_ = 0;
...@@ -1434,7 +1434,7 @@ inline std::string getline() { ...@@ -1434,7 +1434,7 @@ inline std::string getline() {
} }
class Player { class Player {
friend class YGOProEnv; friend class YGOProEnvImpl;
protected: protected:
const std::string nickname_; const std::string nickname_;
...@@ -1590,14 +1590,17 @@ constexpr int32_t rules_ = 5; ...@@ -1590,14 +1590,17 @@ constexpr int32_t rules_ = 5;
constexpr int32_t duel_options_ = ((rules_ & 0xFF) << 16) + (0 & 0xFFFF); constexpr int32_t duel_options_ = ((rules_ & 0xFF) << 16) + (0 & 0xFFFF);
class YGOProEnv : public Env<YGOProEnvSpec> { class YGOProEnvImpl {
protected: protected:
const EnvSpec<YGOProEnvFns> spec_;
constexpr static int init_lp_ = 8000; constexpr static int init_lp_ = 8000;
constexpr static int startcount_ = 5; constexpr static int startcount_ = 5;
constexpr static int drawcount_ = 1; constexpr static int drawcount_ = 1;
std::string deck1_; const std::string deck1_;
std::string deck2_; const std::string deck2_;
std::vector<uint32> main_deck0_; std::vector<uint32> main_deck0_;
std::vector<uint32> main_deck1_; std::vector<uint32> main_deck1_;
std::vector<uint32> extra_deck0_; std::vector<uint32> extra_deck0_;
...@@ -1613,9 +1616,7 @@ protected: ...@@ -1613,9 +1616,7 @@ protected:
const int player_; const int player_;
PlayMode play_mode_; PlayMode play_mode_;
bool verbose_ = false; const bool verbose_ = false;
int max_episode_steps_, elapsed_step_;
PlayerId ai_player_; PlayerId ai_player_;
...@@ -1629,7 +1630,7 @@ protected: ...@@ -1629,7 +1630,7 @@ protected:
PlayerId winner_; PlayerId winner_;
uint8_t win_reason_; uint8_t win_reason_;
bool greedy_reward_; const bool greedy_reward_;
int lp_[2]; int lp_[2];
...@@ -1657,18 +1658,18 @@ protected: ...@@ -1657,18 +1658,18 @@ protected:
// chain // chain
PlayerId chaining_player_; PlayerId chaining_player_;
double step_time_ = 0; // double step_time_ = 0;
uint64_t step_time_count_ = 0; // uint64_t step_time_count_ = 0;
double reset_time_ = 0; // double reset_time_ = 0;
double reset_time_1_ = 0; // double reset_time_1_ = 0;
double reset_time_2_ = 0; // double reset_time_2_ = 0;
double reset_time_3_ = 0; // double reset_time_3_ = 0;
uint64_t reset_time_count_ = 0; // uint64_t reset_time_count_ = 0;
// average time for decks // // average time for decks
ankerl::unordered_dense::map<std::string, double> deck_time_; // ankerl::unordered_dense::map<std::string, double> deck_time_;
ankerl::unordered_dense::map<std::string, uint64_t> deck_time_count_; // ankerl::unordered_dense::map<std::string, uint64_t> deck_time_count_;
const int n_history_actions_; const int n_history_actions_;
...@@ -1702,6 +1703,8 @@ protected: ...@@ -1702,6 +1703,8 @@ protected:
// MSG_SELECT_COUNTER // MSG_SELECT_COUNTER
int n_counters_ = 0; int n_counters_ = 0;
std::mt19937 gen_;
// async reset // async reset
const bool async_reset_; const bool async_reset_;
int n_lives_ = 0; int n_lives_ = 0;
...@@ -1711,10 +1714,8 @@ protected: ...@@ -1711,10 +1714,8 @@ protected:
public: public:
YGOProEnv(const Spec &spec, int env_id) YGOProEnvImpl(const EnvSpec<YGOProEnvFns> &spec, uint32_t env_seed)
: Env<YGOProEnvSpec>(spec, env_id), : spec_(spec), dist_int_(0, 0xffffffff),
max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff),
deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]), deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]),
player_(spec.config["player"_]), players_{nullptr, nullptr}, player_(spec.config["player"_]), players_{nullptr, nullptr},
play_modes_(parse_play_modes(spec.config["play_mode"_])), play_modes_(parse_play_modes(spec.config["play_mode"_])),
...@@ -1728,15 +1729,20 @@ public: ...@@ -1728,15 +1729,20 @@ public:
} }
// fmt::println("env_id: {}, seed: {}, x: {}", env_id_, seed_, dist_int_(gen_)); // fmt::println("env_id: {}, seed: {}, x: {}", env_id_, seed_, dist_int_(gen_));
gen_ = std::mt19937(env_seed);
duel_gen_ = std::mt19937(dist_int_(gen_)); duel_gen_ = std::mt19937(dist_int_(gen_));
if (async_reset_) { if (async_reset_) {
duel_fut_ = pool_.submit_task([ fmt::println("Async reset is deprecated!!!");
this, duel_seed=dist_int_(gen_)] {
return new_duel(duel_seed);
});
} }
// if (async_reset_) {
// duel_fut_ = pool_.submit_task([
// this, duel_seed=dist_int_(gen_)] {
// return new_duel(duel_seed);
// });
// }
int max_options = spec.config["max_options"_]; int max_options = spec.config["max_options"_];
int n_action_feats = spec.state_spec["obs:actions_"_].shape[1]; int n_action_feats = spec.state_spec["obs:actions_"_].shape[1];
history_actions_1_ = TArray<uint8_t>(Array( history_actions_1_ = TArray<uint8_t>(Array(
...@@ -1745,7 +1751,7 @@ public: ...@@ -1745,7 +1751,7 @@ public:
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2}))); ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
} }
~YGOProEnv() { ~YGOProEnvImpl() {
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
if (players_[i] != nullptr) { if (players_[i] != nullptr) {
delete players_[i]; delete players_[i];
...@@ -1757,7 +1763,7 @@ public: ...@@ -1757,7 +1763,7 @@ public:
int max_cards() const { return spec_.config["max_cards"_]; } int max_cards() const { return spec_.config["max_cards"_]; }
bool IsDone() override { return done_; } bool done() const { return done_; }
bool random_mode() const { return play_modes_.size() > 1; } bool random_mode() const { return play_modes_.size() > 1; }
...@@ -1801,8 +1807,8 @@ public: ...@@ -1801,8 +1807,8 @@ public:
return mduel; return mduel;
} }
void Reset() override { void reset() {
clock_t start = clock(); // clock_t start = clock();
if (random_mode()) { if (random_mode()) {
play_mode_ = play_modes_[dist_int_(gen_) % play_modes_.size()]; play_mode_ = play_modes_[dist_int_(gen_) % play_modes_.size()];
...@@ -1826,16 +1832,20 @@ public: ...@@ -1826,16 +1832,20 @@ public:
ha_p_1_ = 0; ha_p_1_ = 0;
ha_p_2_ = 0; ha_p_2_ = 0;
clock_t _start = clock(); // clock_t _start = clock();
intptr_t old_duel = pduel_; intptr_t old_duel = pduel_;
MDuel mduel; if (duel_started_) {
if (async_reset_) { YGO_EndDuel(pduel_);
mduel = duel_fut_.get();
n_lives_ = 1;
} else {
mduel = new_duel(dist_int_(gen_));
} }
MDuel mduel;
// if (async_reset_) {
// mduel = duel_fut_.get();
// n_lives_ = 1;
// } else {
// mduel = new_duel(dist_int_(gen_));
// }
mduel = new_duel(dist_int_(gen_));
auto duel_seed = mduel.seed; auto duel_seed = mduel.seed;
pduel_ = mduel.pduel; pduel_ = mduel.pduel;
...@@ -1933,27 +1943,27 @@ public: ...@@ -1933,27 +1943,27 @@ public:
} }
duel_started_ = true; duel_started_ = true;
eng_flag_ = 0;
winner_ = 255; winner_ = 255;
win_reason_ = 255; win_reason_ = 255;
discard_hand_ = false;
done_ = false;
// update_time_stat(_start, reset_time_count_, reset_time_2_); // update_time_stat(_start, reset_time_count_, reset_time_2_);
// _start = clock(); // _start = clock();
next(); next();
done_ = false; // if (async_reset_) {
elapsed_step_ = 0; // duel_fut_ = pool_.submit_task([
WriteState(0.0); // this, old_duel, duel_seed=dist_int_(gen_)] {
// if (old_duel != 0) {
if (async_reset_) { // YGO_EndDuel(old_duel);
duel_fut_ = pool_.submit_task([ // }
this, old_duel, duel_seed=dist_int_(gen_)] { // return new_duel(duel_seed);
if (old_duel != 0) { // });
YGO_EndDuel(old_duel); // }
}
return new_duel(duel_seed);
});
}
// update_time_stat(_start, reset_time_count_, reset_time_3_); // update_time_stat(_start, reset_time_count_, reset_time_3_);
// update_time_stat(start, reset_time_count_, reset_time_); // update_time_stat(start, reset_time_count_, reset_time_);
...@@ -2200,10 +2210,7 @@ public: ...@@ -2200,10 +2210,7 @@ public:
} }
} }
void Step(const Action &action) override { std::tuple<float, int> step(int idx) {
clock_t start = clock();
int idx = action["action"_];
callback_(idx); callback_(idx);
update_history_actions(to_play_, legal_actions_[idx]); update_history_actions(to_play_, legal_actions_[idx]);
...@@ -2275,17 +2282,15 @@ public: ...@@ -2275,17 +2282,15 @@ public:
} }
update_time_stat(start, step_time_count_, step_time_); // update_time_stat(start, step_time_count_, step_time_);
step_time_count_++; // step_time_count_++;
double step_time = 0;
if (done_) {
step_time = step_time_;
step_time_ = 0;
step_time_count_ = 0;
}
WriteState(reward, win_reason_, step_time); // double step_time = 0;
// if (done_) {
// step_time = step_time_;
// step_time_ = 0;
// step_time_count_ = 0;
// }
// if (done_) { // if (done_) {
// update_time_stat(deck_name_[0], step_time_); // update_time_stat(deck_name_[0], step_time_);
...@@ -2296,6 +2301,82 @@ public: ...@@ -2296,6 +2301,82 @@ public:
// if (step_time_count_ % 3000 == 0) { // if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000); // fmt::println("Step time: {:.3f}", step_time_ * 1000);
// } // }
return {reward, reason};
}
using YGOProEnvSpec = EnvSpec<YGOProEnvFns>;
using State =
Dict<typename YGOProEnvSpec::StateKeys,
typename SpecToTArray<typename YGOProEnvSpec::StateSpec::Values>::Type>;
void WriteState(
State &state, float reward, int win_reason = 0, double step_time = 0.0) {
int n_options = legal_actions_.size();
state["reward"_] = reward;
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason;
if (reward != 0.0) {
state["info:step_time"_][0] = step_time;
state["info:step_time"_][1] = step_time;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
}
if (n_options == 0) {
state["info:num_options"_] = 1;
state["obs:global_"_][22] = uint8_t(1);
return;
}
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// we can't shuffle because idx must be stable in callback
if (n_options > max_options()) {
legal_actions_.resize(max_options());
}
n_options = legal_actions_.size();
state["info:num_options"_] = n_options;
for (int i = 0; i < n_options; ++i) {
auto &action = legal_actions_[i];
action.msg_ = msg_;
const auto &spec = action.spec_;
if (!spec.empty()) {
const auto& spec_info = find_spec_info(spec_infos, spec);
action.spec_index_ = spec_info.index;
if (action.cid_ == 0) {
action.cid_ = spec_info.cid;
}
}
}
_set_obs_actions(state["obs:actions_"_], legal_actions_);
// write history actions
auto ha_p = to_play_ == 0 ? ha_p_1_ : ha_p_2_;
auto &history_actions = to_play_ == 0 ? history_actions_1_ : history_actions_2_;
int offset = n_history_actions_ - ha_p;
int n_h_action_feats = history_actions.Shape()[1];
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions[ha_p].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions.Data(), n_h_action_feats * ha_p);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 3)) == 0) {
break;
}
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 12)));
state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(turn_diff);
}
} }
private: private:
...@@ -2701,77 +2782,6 @@ private: ...@@ -2701,77 +2782,6 @@ private:
// ygopro-core API // ygopro-core API
void WriteState(float reward, int win_reason = 0, double step_time = 0.0) {
State state = Allocate();
int n_options = legal_actions_.size();
state["reward"_] = reward;
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason;
if (reward != 0.0) {
state["info:step_time"_][0] = step_time;
state["info:step_time"_][1] = step_time;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
}
if (n_options == 0) {
state["info:num_options"_] = 1;
state["obs:global_"_][22] = uint8_t(1);
return;
}
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// we can't shuffle because idx must be stable in callback
if (n_options > max_options()) {
legal_actions_.resize(max_options());
}
n_options = legal_actions_.size();
state["info:num_options"_] = n_options;
for (int i = 0; i < n_options; ++i) {
auto &action = legal_actions_[i];
action.msg_ = msg_;
const auto &spec = action.spec_;
if (!spec.empty()) {
const auto& spec_info = find_spec_info(spec_infos, spec);
action.spec_index_ = spec_info.index;
if (action.cid_ == 0) {
action.cid_ = spec_info.cid;
}
}
}
_set_obs_actions(state["obs:actions_"_], legal_actions_);
// write history actions
auto ha_p = to_play_ == 0 ? ha_p_1_ : ha_p_2_;
auto &history_actions = to_play_ == 0 ? history_actions_1_ : history_actions_2_;
int offset = n_history_actions_ - ha_p;
int n_h_action_feats = history_actions.Shape()[1];
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions[ha_p].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions.Data(), n_h_action_feats * ha_p);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 3)) == 0) {
break;
}
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 12)));
state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(turn_diff);
}
}
void show_decision(int idx) { void show_decision(int idx) {
std::string s; std::string s;
const auto& a = legal_actions_[idx]; const auto& a = legal_actions_[idx];
...@@ -5050,16 +5060,52 @@ private: ...@@ -5050,16 +5060,52 @@ private:
void _duel_end(uint8_t player, uint8_t reason) { void _duel_end(uint8_t player, uint8_t reason) {
winner_ = player; winner_ = player;
win_reason_ = reason; win_reason_ = reason;
if (async_reset_) { // if (async_reset_) {
n_lives_--; // n_lives_--;
} else { // } else {
YGO_EndDuel(pduel_); // YGO_EndDuel(pduel_);
} // }
YGO_EndDuel(pduel_);
duel_started_ = false; duel_started_ = false;
} }
}; };
class YGOProEnv : public Env<YGOProEnvSpec> {
protected:
const int max_episode_steps_;
int elapsed_step_;
std::uniform_int_distribution<uint64_t> dist_int_;
YGOProEnvImpl env_impl_;
public:
YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id),
max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff),
env_impl_(YGOProEnvImpl(spec, dist_int_(gen_))) {
}
bool IsDone() override { return env_impl_.done(); }
void Reset() override {
env_impl_.reset();
elapsed_step_ = 0;
State state = Allocate();
env_impl_.WriteState(state, 0.0);
}
void Step(const Action &action) override {
auto [reward, win_reason] = env_impl_.step(action["action"_]);
State state = Allocate();
env_impl_.WriteState(state, reward, win_reason);
}
};
using YGOProEnvPool = AsyncEnvPool<YGOProEnv>; using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
} // namespace ygopro } // namespace ygopro
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment