Commit 77492ca0, authored by sbl1996@126.com

Add oppo_info

parent 3dfee5f5
...@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off ...@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
greedy_reward=args.greedy_reward if not eval else True, greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode, play_mode=mode,
timeout=args.timeout, timeout=args.timeout,
oppo_info=args.m2.oppo_info if eval else args.m1.oppo_info,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
return envs return envs
......
This diff is collapsed.
...@@ -1527,7 +1527,8 @@ public: ...@@ -1527,7 +1527,8 @@ public:
"verbose"_.Bind(false), "max_options"_.Bind(16), "verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16), "max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(false), "record"_.Bind(false), "async_reset"_.Bind(false),
"greedy_reward"_.Bind(true), "timeout"_.Bind(600)); "greedy_reward"_.Bind(true), "timeout"_.Bind(600),
"oppo_info"_.Bind(false));
} }
template <typename Config> template <typename Config>
static decltype(auto) StateSpec(const Config &conf) { static decltype(auto) StateSpec(const Config &conf) {
...@@ -1539,6 +1540,7 @@ public: ...@@ -1539,6 +1540,7 @@ public:
Spec<uint8_t>({conf["max_options"_], n_action_feats})), Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind( "obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})), Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"obs:g_cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})), "info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})), "info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})), "info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
...@@ -2259,8 +2261,13 @@ public: ...@@ -2259,8 +2261,13 @@ public:
} }
if (play_mode_ == kSelfPlay) { if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player // if (spec_.config["oppo_info"_]) {
reward = winner_ == player ? base_reward : -base_reward; if (false) {
reward = winner_ == 0 ? base_reward : -base_reward;
} else {
// to_play_ is the previous player
reward = winner_ == player ? base_reward : -base_reward;
}
} else { } else {
reward = winner_ == ai_player_ ? base_reward : -base_reward; reward = winner_ == ai_player_ ? base_reward : -base_reward;
} }
...@@ -2331,6 +2338,9 @@ public: ...@@ -2331,6 +2338,9 @@ public:
} }
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_); auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
if (spec_.config["oppo_info"_]) {
_set_obs_g_cards(state["obs:g_cards_"_]);
}
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards); _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
...@@ -2438,8 +2448,30 @@ private: ...@@ -2438,8 +2448,30 @@ private:
return {spec_infos, loc_n_cards}; return {spec_infos, loc_n_cards};
} }
// Fills `f_cards` with one feature row per card, for both players.
//
// Players are visited in order 0, 1. For each player the locations are
// visited in the fixed order deck, hand, monster zone, spell/trap zone,
// graveyard, removed, extra deck. Rows are written consecutively starting
// at row 0 through `_set_obs_card_`.
//
// NOTE(review): nothing in this function caps the row offset; it presumably
// relies on `f_cards` having a row for every card of both players — confirm
// against the "obs:g_cards_" spec.
void _set_obs_g_cards(TArray<uint8_t> &f_cards) {
  // The location order is the same for both players, so it is built once
  // here rather than heap-allocating an identical vector per player.
  const uint8_t locations[] = {
      LOCATION_DECK,  LOCATION_HAND,    LOCATION_MZONE, LOCATION_SZONE,
      LOCATION_GRAVE, LOCATION_REMOVED, LOCATION_EXTRA,
  };

  int offset = 0;
  for (int player = 0; player < 2; ++player) {
    for (const uint8_t location : locations) {
      const std::vector<Card> cards = get_cards_in_location(player, location);
      for (const Card &card : cards) {
        const CardId card_id = c_get_card_id(card.code_);
        // hide = false: no card is masked in this observation.
        // NOTE(review): `global` is passed as false, so the controller
        // column is still written relative to to_play_ rather than as the
        // absolute controller index — confirm this is intended.
        _set_obs_card_(f_cards, offset, card, false, card_id, false);
        ++offset;
      }
    }
  }
}
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c, void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide, CardId card_id = 0) { bool hide, CardId card_id = 0, bool global = false) {
// check offset exceeds max_cards // check offset exceeds max_cards
uint8_t location = c.location_; uint8_t location = c.location_;
bool overlay = location & LOCATION_OVERLAY; bool overlay = location & LOCATION_OVERLAY;
...@@ -2462,7 +2494,7 @@ private: ...@@ -2462,7 +2494,7 @@ private:
seq = c.sequence_ + 1; seq = c.sequence_ + 1;
} }
f_cards(offset, 3) = seq; f_cards(offset, 3) = seq;
f_cards(offset, 4) = (c.controler_ != to_play_) ? 1 : 0; f_cards(offset, 4) = global ? c.controler_ : ((c.controler_ != to_play_) ? 1 : 0);
if (overlay) { if (overlay) {
f_cards(offset, 5) = position_to_id(POS_FACEUP); f_cards(offset, 5) = position_to_id(POS_FACEUP);
f_cards(offset, 6) = 1; f_cards(offset, 6) = 1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment