Commit b7e43382 authored by sbl1996@126.com's avatar sbl1996@126.com

Seperate implementation of YGOProEnv

parent 112f1bc6
......@@ -1062,7 +1062,7 @@ public:
};
class Card {
friend class YGOProEnv;
friend class YGOProEnvImpl;
protected:
CardCode code_ = 0;
......@@ -1434,7 +1434,7 @@ inline std::string getline() {
}
class Player {
friend class YGOProEnv;
friend class YGOProEnvImpl;
protected:
const std::string nickname_;
......@@ -1590,14 +1590,17 @@ constexpr int32_t rules_ = 5;
constexpr int32_t duel_options_ = ((rules_ & 0xFF) << 16) + (0 & 0xFFFF);
class YGOProEnv : public Env<YGOProEnvSpec> {
class YGOProEnvImpl {
protected:
const EnvSpec<YGOProEnvFns> spec_;
constexpr static int init_lp_ = 8000;
constexpr static int startcount_ = 5;
constexpr static int drawcount_ = 1;
std::string deck1_;
std::string deck2_;
const std::string deck1_;
const std::string deck2_;
std::vector<uint32> main_deck0_;
std::vector<uint32> main_deck1_;
std::vector<uint32> extra_deck0_;
......@@ -1613,9 +1616,7 @@ protected:
const int player_;
PlayMode play_mode_;
bool verbose_ = false;
int max_episode_steps_, elapsed_step_;
const bool verbose_ = false;
PlayerId ai_player_;
......@@ -1629,7 +1630,7 @@ protected:
PlayerId winner_;
uint8_t win_reason_;
bool greedy_reward_;
const bool greedy_reward_;
int lp_[2];
......@@ -1657,18 +1658,18 @@ protected:
// chain
PlayerId chaining_player_;
double step_time_ = 0;
uint64_t step_time_count_ = 0;
// double step_time_ = 0;
// uint64_t step_time_count_ = 0;
double reset_time_ = 0;
double reset_time_1_ = 0;
double reset_time_2_ = 0;
double reset_time_3_ = 0;
uint64_t reset_time_count_ = 0;
// double reset_time_ = 0;
// double reset_time_1_ = 0;
// double reset_time_2_ = 0;
// double reset_time_3_ = 0;
// uint64_t reset_time_count_ = 0;
// average time for decks
ankerl::unordered_dense::map<std::string, double> deck_time_;
ankerl::unordered_dense::map<std::string, uint64_t> deck_time_count_;
// // average time for decks
// ankerl::unordered_dense::map<std::string, double> deck_time_;
// ankerl::unordered_dense::map<std::string, uint64_t> deck_time_count_;
const int n_history_actions_;
......@@ -1702,6 +1703,8 @@ protected:
// MSG_SELECT_COUNTER
int n_counters_ = 0;
std::mt19937 gen_;
// async reset
const bool async_reset_;
int n_lives_ = 0;
......@@ -1711,10 +1714,8 @@ protected:
public:
YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id),
max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff),
YGOProEnvImpl(const EnvSpec<YGOProEnvFns> &spec, uint32_t env_seed)
: spec_(spec), dist_int_(0, 0xffffffff),
deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]),
player_(spec.config["player"_]), players_{nullptr, nullptr},
play_modes_(parse_play_modes(spec.config["play_mode"_])),
......@@ -1728,15 +1729,20 @@ public:
}
// fmt::println("env_id: {}, seed: {}, x: {}", env_id_, seed_, dist_int_(gen_));
gen_ = std::mt19937(env_seed);
duel_gen_ = std::mt19937(dist_int_(gen_));
if (async_reset_) {
duel_fut_ = pool_.submit_task([
this, duel_seed=dist_int_(gen_)] {
return new_duel(duel_seed);
});
fmt::println("Async reset is deprecated!!!");
}
// if (async_reset_) {
// duel_fut_ = pool_.submit_task([
// this, duel_seed=dist_int_(gen_)] {
// return new_duel(duel_seed);
// });
// }
int max_options = spec.config["max_options"_];
int n_action_feats = spec.state_spec["obs:actions_"_].shape[1];
history_actions_1_ = TArray<uint8_t>(Array(
......@@ -1745,7 +1751,7 @@ public:
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
}
~YGOProEnv() {
~YGOProEnvImpl() {
for (int i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
delete players_[i];
......@@ -1757,7 +1763,7 @@ public:
int max_cards() const { return spec_.config["max_cards"_]; }
bool IsDone() override { return done_; }
bool done() const { return done_; }
bool random_mode() const { return play_modes_.size() > 1; }
......@@ -1801,8 +1807,8 @@ public:
return mduel;
}
void Reset() override {
clock_t start = clock();
void reset() {
// clock_t start = clock();
if (random_mode()) {
play_mode_ = play_modes_[dist_int_(gen_) % play_modes_.size()];
......@@ -1826,16 +1832,20 @@ public:
ha_p_1_ = 0;
ha_p_2_ = 0;
clock_t _start = clock();
// clock_t _start = clock();
intptr_t old_duel = pduel_;
MDuel mduel;
if (async_reset_) {
mduel = duel_fut_.get();
n_lives_ = 1;
} else {
mduel = new_duel(dist_int_(gen_));
if (duel_started_) {
YGO_EndDuel(pduel_);
}
MDuel mduel;
// if (async_reset_) {
// mduel = duel_fut_.get();
// n_lives_ = 1;
// } else {
// mduel = new_duel(dist_int_(gen_));
// }
mduel = new_duel(dist_int_(gen_));
auto duel_seed = mduel.seed;
pduel_ = mduel.pduel;
......@@ -1933,27 +1943,27 @@ public:
}
duel_started_ = true;
eng_flag_ = 0;
winner_ = 255;
win_reason_ = 255;
discard_hand_ = false;
done_ = false;
// update_time_stat(_start, reset_time_count_, reset_time_2_);
// _start = clock();
next();
done_ = false;
elapsed_step_ = 0;
WriteState(0.0);
if (async_reset_) {
duel_fut_ = pool_.submit_task([
this, old_duel, duel_seed=dist_int_(gen_)] {
if (old_duel != 0) {
YGO_EndDuel(old_duel);
}
return new_duel(duel_seed);
});
}
// if (async_reset_) {
// duel_fut_ = pool_.submit_task([
// this, old_duel, duel_seed=dist_int_(gen_)] {
// if (old_duel != 0) {
// YGO_EndDuel(old_duel);
// }
// return new_duel(duel_seed);
// });
// }
// update_time_stat(_start, reset_time_count_, reset_time_3_);
// update_time_stat(start, reset_time_count_, reset_time_);
......@@ -2200,10 +2210,7 @@ public:
}
}
void Step(const Action &action) override {
clock_t start = clock();
int idx = action["action"_];
std::tuple<float, int> step(int idx) {
callback_(idx);
update_history_actions(to_play_, legal_actions_[idx]);
......@@ -2275,17 +2282,15 @@ public:
}
update_time_stat(start, step_time_count_, step_time_);
step_time_count_++;
double step_time = 0;
if (done_) {
step_time = step_time_;
step_time_ = 0;
step_time_count_ = 0;
}
// update_time_stat(start, step_time_count_, step_time_);
// step_time_count_++;
WriteState(reward, win_reason_, step_time);
// double step_time = 0;
// if (done_) {
// step_time = step_time_;
// step_time_ = 0;
// step_time_count_ = 0;
// }
// if (done_) {
// update_time_stat(deck_name_[0], step_time_);
......@@ -2296,6 +2301,82 @@ public:
// if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000);
// }
return {reward, reason};
}
using YGOProEnvSpec = EnvSpec<YGOProEnvFns>;
using State =
Dict<typename YGOProEnvSpec::StateKeys,
typename SpecToTArray<typename YGOProEnvSpec::StateSpec::Values>::Type>;
void WriteState(
State &state, float reward, int win_reason = 0, double step_time = 0.0) {
int n_options = legal_actions_.size();
state["reward"_] = reward;
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason;
if (reward != 0.0) {
state["info:step_time"_][0] = step_time;
state["info:step_time"_][1] = step_time;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
}
if (n_options == 0) {
state["info:num_options"_] = 1;
state["obs:global_"_][22] = uint8_t(1);
return;
}
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// we can't shuffle because idx must be stable in callback
if (n_options > max_options()) {
legal_actions_.resize(max_options());
}
n_options = legal_actions_.size();
state["info:num_options"_] = n_options;
for (int i = 0; i < n_options; ++i) {
auto &action = legal_actions_[i];
action.msg_ = msg_;
const auto &spec = action.spec_;
if (!spec.empty()) {
const auto& spec_info = find_spec_info(spec_infos, spec);
action.spec_index_ = spec_info.index;
if (action.cid_ == 0) {
action.cid_ = spec_info.cid;
}
}
}
_set_obs_actions(state["obs:actions_"_], legal_actions_);
// write history actions
auto ha_p = to_play_ == 0 ? ha_p_1_ : ha_p_2_;
auto &history_actions = to_play_ == 0 ? history_actions_1_ : history_actions_2_;
int offset = n_history_actions_ - ha_p;
int n_h_action_feats = history_actions.Shape()[1];
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions[ha_p].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions.Data(), n_h_action_feats * ha_p);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 3)) == 0) {
break;
}
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 12)));
state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(turn_diff);
}
}
private:
......@@ -2701,77 +2782,6 @@ private:
// ygopro-core API
void WriteState(float reward, int win_reason = 0, double step_time = 0.0) {
State state = Allocate();
int n_options = legal_actions_.size();
state["reward"_] = reward;
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason;
if (reward != 0.0) {
state["info:step_time"_][0] = step_time;
state["info:step_time"_][1] = step_time;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
}
if (n_options == 0) {
state["info:num_options"_] = 1;
state["obs:global_"_][22] = uint8_t(1);
return;
}
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
// we can't shuffle because idx must be stable in callback
if (n_options > max_options()) {
legal_actions_.resize(max_options());
}
n_options = legal_actions_.size();
state["info:num_options"_] = n_options;
for (int i = 0; i < n_options; ++i) {
auto &action = legal_actions_[i];
action.msg_ = msg_;
const auto &spec = action.spec_;
if (!spec.empty()) {
const auto& spec_info = find_spec_info(spec_infos, spec);
action.spec_index_ = spec_info.index;
if (action.cid_ == 0) {
action.cid_ = spec_info.cid;
}
}
}
_set_obs_actions(state["obs:actions_"_], legal_actions_);
// write history actions
auto ha_p = to_play_ == 0 ? ha_p_1_ : ha_p_2_;
auto &history_actions = to_play_ == 0 ? history_actions_1_ : history_actions_2_;
int offset = n_history_actions_ - ha_p;
int n_h_action_feats = history_actions.Shape()[1];
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions[ha_p].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions.Data(), n_h_action_feats * ha_p);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 3)) == 0) {
break;
}
// state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 12)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 12)));
state["obs:h_actions_"_](i, 12) = static_cast<uint8_t>(turn_diff);
}
}
void show_decision(int idx) {
std::string s;
const auto& a = legal_actions_[idx];
......@@ -5050,16 +5060,52 @@ private:
void _duel_end(uint8_t player, uint8_t reason) {
winner_ = player;
win_reason_ = reason;
if (async_reset_) {
n_lives_--;
} else {
YGO_EndDuel(pduel_);
}
// if (async_reset_) {
// n_lives_--;
// } else {
// YGO_EndDuel(pduel_);
// }
YGO_EndDuel(pduel_);
duel_started_ = false;
}
};
class YGOProEnv : public Env<YGOProEnvSpec> {
protected:
const int max_episode_steps_;
int elapsed_step_;
std::uniform_int_distribution<uint64_t> dist_int_;
YGOProEnvImpl env_impl_;
public:
YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id),
max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff),
env_impl_(YGOProEnvImpl(spec, dist_int_(gen_))) {
}
bool IsDone() override { return env_impl_.done(); }
void Reset() override {
env_impl_.reset();
elapsed_step_ = 0;
State state = Allocate();
env_impl_.WriteState(state, 0.0);
}
void Step(const Action &action) override {
auto [reward, win_reason] = env_impl_.step(action["action"_]);
State state = Allocate();
env_impl_.WriteState(state, reward, win_reason);
}
};
using YGOProEnvPool = AsyncEnvPool<YGOProEnv>;
} // namespace ygopro
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment