Commit 77492ca0 authored by sbl1996@126.com's avatar sbl1996@126.com

Add oppo_info

parent 3dfee5f5
......@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
timeout=args.timeout,
oppo_info=args.m2.oppo_info if eval else args.m1.oppo_info,
)
envs.num_envs = num_envs
return envs
......
This diff is collapsed.
......@@ -1527,7 +1527,8 @@ public:
"verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(false),
"greedy_reward"_.Bind(true), "timeout"_.Bind(600));
"greedy_reward"_.Bind(true), "timeout"_.Bind(600),
"oppo_info"_.Bind(false));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
......@@ -1539,6 +1540,7 @@ public:
Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"obs:g_cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
......@@ -2259,8 +2261,13 @@ public:
}
if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player
reward = winner_ == player ? base_reward : -base_reward;
// if (spec_.config["oppo_info"_]) {
if (false) {
reward = winner_ == 0 ? base_reward : -base_reward;
} else {
// to_play_ is the previous player
reward = winner_ == player ? base_reward : -base_reward;
}
} else {
reward = winner_ == ai_player_ ? base_reward : -base_reward;
}
......@@ -2331,6 +2338,9 @@ public:
}
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
if (spec_.config["oppo_info"_]) {
_set_obs_g_cards(state["obs:g_cards_"_]);
}
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
......@@ -2438,8 +2448,30 @@ private:
return {spec_infos, loc_n_cards};
}
// Fills `f_cards` with one feature row per card, covering both players.
//
// Players 0 and 1 are visited in that order; for each player the locations
// listed below are walked in a fixed order, and every card returned by
// get_cards_in_location() is encoded by _set_obs_card_() into the next free
// row. Rows are assigned consecutively starting at 0 and are shared across
// players and locations, so the layout depends only on this traversal order.
//
// Cards are written with hide == false, i.e. nothing is masked here.
//
// NOTE(review): the last argument is the `global` flag of _set_obs_card_ and
// is passed as false, so the controller feature stays relative to to_play_
// rather than being the absolute controller index — confirm this is intended.
// NOTE(review): no bound on the row index is enforced in this function;
// presumably the caller's array has enough rows for every card — verify.
void _set_obs_g_cards(TArray<uint8_t> &f_cards) {
  // Locations scanned for each player, in output order.
  const uint8_t locations[] = {
      LOCATION_DECK,  LOCATION_HAND,    LOCATION_MZONE, LOCATION_SZONE,
      LOCATION_GRAVE, LOCATION_REMOVED, LOCATION_EXTRA,
  };

  int row = 0;  // next row of f_cards to write
  for (int player = 0; player < 2; ++player) {
    for (const uint8_t loc : locations) {
      const std::vector<Card> zone_cards = get_cards_in_location(player, loc);
      for (const Card &card : zone_cards) {
        const CardId id = c_get_card_id(card.code_);
        _set_obs_card_(f_cards, row, card, false, id, false);
        ++row;
      }
    }
  }
}
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide, CardId card_id = 0) {
bool hide, CardId card_id = 0, bool global = false) {
// check offset exceeds max_cards
uint8_t location = c.location_;
bool overlay = location & LOCATION_OVERLAY;
......@@ -2462,7 +2494,7 @@ private:
seq = c.sequence_ + 1;
}
f_cards(offset, 3) = seq;
f_cards(offset, 4) = (c.controler_ != to_play_) ? 1 : 0;
f_cards(offset, 4) = global ? c.controler_ : ((c.controler_ != to_play_) ? 1 : 0);
if (overlay) {
f_cards(offset, 5) = position_to_id(POS_FACEUP);
f_cards(offset, 6) = 1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment