Commit e79a43ef authored by sbl1996@126.com

Use both history actions (cheat)

parent 892c7364
......@@ -18,7 +18,7 @@ class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
......@@ -26,7 +26,6 @@ class ActionEncoder(nn.Module):
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
......@@ -38,9 +37,9 @@ class ActionEncoder(nn.Module):
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
return jnp.concatenate([
x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib], axis=-1)
xs = [x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib]
return xs
class CardEncoder(nn.Module):
......@@ -169,7 +168,8 @@ class Encoder(nn.Module):
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
action_encoder = ActionEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
......@@ -216,7 +216,13 @@ class Encoder(nn.Module):
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
......@@ -240,7 +246,7 @@ class Encoder(nn.Module):
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
x_a_feats = action_encoder(x_actions[..., 2:])
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
......
......@@ -1263,7 +1263,7 @@ public:
"obs:actions_"_.Bind(
Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats})),
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
......@@ -1391,15 +1391,10 @@ protected:
const int n_history_actions_;
// circular buffer for history actions of player 0
TArray<uint8_t> history_actions_0_;
int ha_p_0_ = 0;
std::vector<CardId> h_card_ids_0_;
// circular buffer for history actions of player 1
TArray<uint8_t> history_actions_1_;
int ha_p_1_ = 0;
std::vector<CardId> h_card_ids_1_;
// circular buffer for history actions
TArray<uint8_t> history_actions_;
int ha_p_ = 0;
std::vector<CardId> h_card_ids_;
std::unordered_set<std::string> revealed_;
......@@ -1461,12 +1456,9 @@ public:
int max_options = spec.config["max_options"_];
int n_action_feats = spec.state_spec["obs:actions_"_].shape[1];
h_card_ids_0_.resize(max_options);
h_card_ids_1_.resize(max_options);
history_actions_0_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats})));
history_actions_1_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats})));
h_card_ids_.resize(max_options);
history_actions_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
}
~YGOProEnv() {
......@@ -1537,10 +1529,8 @@ public:
turn_count_ = 0;
ms_idx_ = -1;
history_actions_0_.Zero();
history_actions_1_.Zero();
ha_p_0_ = 0;
ha_p_1_ = 0;
history_actions_.Zero();
ha_p_ = 0;
clock_t _start = clock();
......@@ -1803,23 +1793,22 @@ public:
}
// Stores, at slot `idx` of the card-id cache, the value parse_card_id()
// returns for options_[idx] and `player`.
// NOTE(review): this span comes from a rendered diff with the +/- markers
// stripped, so both versions appear: the per-player
// `h_card_ids_0_`/`h_card_ids_1_` lookup and the `h_card_ids[idx]` assignment
// look like the removed lines, and the `h_card_ids_[idx]` assignment like the
// added line — confirm against the repository before relying on either.
void update_h_card_ids(PlayerId player, int idx) {
auto &h_card_ids = player == 0 ? h_card_ids_0_ : h_card_ids_1_;
h_card_ids[idx] = parse_card_id(options_[idx], player);
h_card_ids_[idx] = parse_card_id(options_[idx], player);
}
// Records the option at `idx` as the newest entry of the history-action
// buffer: the write cursor is decremented (wrapping to
// n_history_actions_ - 1), the row is zeroed, _set_obs_action() fills it, and
// columns 13 and 14 receive the acting player and the current turn count.
// NOTE(review): this span comes from a rendered diff with the +/- markers
// stripped. The statements using `history_actions`, `ha_p` and `h_card_ids`
// (the per-player `_0_`/`_1_` buffers) look like the removed version; the ones
// using `history_actions_`, `ha_p_` and `h_card_ids_` look like the added
// version. As rendered the braces do not balance — confirm against the
// repository.
void update_history_actions(PlayerId player, int idx) {
auto &history_actions =
player == 0 ? history_actions_0_ : history_actions_1_;
auto &ha_p = player == 0 ? ha_p_0_ : ha_p_1_;
const auto &h_card_ids = player == 0 ? h_card_ids_0_ : h_card_ids_1_;
ha_p--;
if (ha_p < 0) {
ha_p = n_history_actions_ - 1;
// Skip recording when the message is MSG_SELECT_CHAIN and the option string
// starts with 'c'.
// NOTE(review): `&` is the bitwise operator applied to two bool operands; the
// result matches `&&` here, but both sides are always evaluated. `&&` is
// presumably what was intended — confirm.
if ((msg_ == MSG_SELECT_CHAIN) & (options_[idx][0] == 'c')) {
return;
}
history_actions[ha_p].Zero();
_set_obs_action(history_actions, ha_p, msg_, options_[idx], {},
h_card_ids[idx]);
// Move the write cursor one slot back, wrapping to the last row.
ha_p_--;
if (ha_p_ < 0) {
ha_p_ = n_history_actions_ - 1;
}
// Clear the row before writing so fields left unset by _set_obs_action stay 0.
history_actions_[ha_p_].Zero();
_set_obs_action(history_actions_, ha_p_, msg_, options_[idx], {},
h_card_ids_[idx]);
// Columns 13 and 14 are the two extra per-row features: raw player id and raw
// turn count.
// NOTE(review): the uint8_t cast wraps if turn_count_ exceeds 255 — confirm
// the turn counter is bounded below that.
history_actions_[ha_p_](13) = static_cast<uint8_t>(player);
history_actions_[ha_p_](14) = static_cast<uint8_t>(turn_count_);
}
void show_deck(const std::vector<CardCode> &deck, const std::string &prefix) const {
......@@ -1849,7 +1838,7 @@ public:
}
void show_history_actions(PlayerId player) const {
const auto &ha = player == 0 ? history_actions_0_ : history_actions_1_;
const auto &ha = history_actions_;
// print card ids of history actions
for (int i = 0; i < n_history_actions_; ++i) {
fmt::print("history {}\n", i);
......@@ -2064,7 +2053,7 @@ private:
feat(2) = op_lp_1;
feat(3) = op_lp_2;
feat(4) = std::min(turn_count_, 8);
feat(4) = std::min(turn_count_, 16);
feat(5) = phase2id.at(current_phase_);
feat(6) = (me == 0) ? 1 : 0;
feat(7) = (me == tp_) ? 1 : 0;
......@@ -2407,34 +2396,38 @@ private:
n_options = options_.size();
state["info:num_options"_] = n_options;
// update h_card_ids from state
auto &h_card_ids = to_play_ == 0 ? h_card_ids_0_ : h_card_ids_1_;
// update_h_card_ids from state
for (int i = 0; i < n_options; ++i) {
uint8_t spec_index1 = state["obs:actions_"_](i, 0);
uint8_t spec_index2 = state["obs:actions_"_](i, 1);
uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
if (spec_index == 0) {
h_card_ids[i] = 0;
h_card_ids_[i] = 0;
} else {
uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
h_card_ids[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
h_card_ids_[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
}
}
// write history actions
const auto &ha_p = to_play_ == 0 ? ha_p_0_ : ha_p_1_;
const auto &history_actions =
to_play_ == 0 ? history_actions_0_ : history_actions_1_;
int n1 = n_history_actions_ - ha_p;
int n_action_feats = state["obs:actions_"_].Shape()[1];
int offset = n_history_actions_ - ha_p_;
int n_h_action_feats = history_actions_.Shape()[1];
state["obs:h_actions_"_].Assign((uint8_t *)history_actions[ha_p].Data(),
n_action_feats * n1);
state["obs:h_actions_"_][n1].Assign((uint8_t *)history_actions.Data(),
n_action_feats * ha_p);
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions_[ha_p_].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions_.Data(), n_h_action_feats * ha_p_);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 2)) == 0) {
break;
}
state["obs:h_actions_"_](i, 13) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 13)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 14)));
state["obs:h_actions_"_](i, 14) = static_cast<uint8_t>(turn_diff);
}
}
void show_decision(int idx) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment