Commit 77492ca0 authored by sbl1996@126.com

Add oppo_info

parent 3dfee5f5
...@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off ...@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
greedy_reward=args.greedy_reward if not eval else True, greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode, play_mode=mode,
timeout=args.timeout, timeout=args.timeout,
oppo_info=args.m2.oppo_info if eval else args.m1.oppo_info,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
return envs return envs
......
...@@ -113,7 +113,7 @@ class CardEncoder(nn.Module): ...@@ -113,7 +113,7 @@ class CardEncoder(nn.Module):
c_mask = x_loc == 0 c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False) c_mask = c_mask.at[:, 0].set(False)
x_owner = embed(2, c // 16)(x1[:, :, 2]) x_owner = embed(3, c // 16)(x1[:, :, 2])
x_position = embed(9, c // 16)(x1[:, :, 3]) x_position = embed(9, c // 16)(x1[:, :, 3])
x_overley = embed(2, c // 16)(x1[:, :, 4]) x_overley = embed(2, c // 16)(x1[:, :, 4])
x_attribute = embed(8, c // 16)(x1[:, :, 5]) x_attribute = embed(8, c // 16)(x1[:, :, 5])
...@@ -208,6 +208,7 @@ class Encoder(nn.Module): ...@@ -208,6 +208,7 @@ class Encoder(nn.Module):
card_mask: bool = False card_mask: bool = False
noam: bool = False noam: bool = False
action_feats: bool = True action_feats: bool = True
oppo_info: bool = False
version: int = 0 version: int = 0
@nn.compact @nn.compact
...@@ -227,28 +228,46 @@ class Encoder(nn.Module): ...@@ -227,28 +228,46 @@ class Encoder(nn.Module):
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype) fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim) id_embed = embed(n_embed, embed_dim)
card_encoder = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)
ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1 ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
action_encoder = ActionEncoderCls( action_encoder = ActionEncoderCls(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype) channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards_g = x['g_cards_'] if self.oppo_info else None
x_cards = x['cards_'] x_cards = x['cards_']
x_global = x['global_'] x_global = x['global_']
x_actions = x['actions_'] x_actions = x['actions_']
x_h_actions = x['h_actions_'] x_h_actions = x['h_actions_']
batch_size = x_cards.shape[0] batch_size = x_global.shape[0]
valid = x_global[:, -1] == 0 valid = x_global[:, -1] == 0
n_cards = x_cards.shape[-2]
if self.oppo_info:
x_cards = jnp.concatenate([x_cards, x_cards_g], axis=-2)
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32)) x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id) x_id = id_embed(x_id)
if self.freeze_id: if self.freeze_id:
x_id = jax.lax.stop_gradient(x_id) x_id = jax.lax.stop_gradient(x_id)
f_cards, c_mask = card_encoder(x_id, x_cards[:, :, 2:])
if self.oppo_info:
f_cards_me, f_cards_g = jnp.split(f_cards, [n_cards], axis=-2)
else:
f_cards_me, f_cards_g = f_cards, None
# Cards # Cards
f_cards, c_mask = CardEncoder( fs_g_card = []
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_id, x_cards[:, :, 2:]) for i, f_cards in enumerate([f_cards_g, f_cards_me]):
if f_cards is None:
fs_g_card.append(None)
continue
name = 'g_card_embed' if i == 0 else 'g_g_card_embed'
g_card_embed = self.param( g_card_embed = self.param(
'g_card_embed', name,
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02, lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype) (1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype) f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
...@@ -265,6 +284,8 @@ class Encoder(nn.Module): ...@@ -265,6 +284,8 @@ class Encoder(nn.Module):
f_cards, src_key_padding_mask=c_mask) f_cards, src_key_padding_mask=c_mask)
f_cards = layer_norm(dtype=self.dtype)(f_cards) f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0] f_g_card = f_cards[:, 0]
fs_g_card.append(f_g_card)
f_g_g_card, f_g_card = fs_g_card
# Global # Global
x_global = GlobalEncoder( x_global = GlobalEncoder(
...@@ -412,7 +433,8 @@ class Encoder(nn.Module): ...@@ -412,7 +433,8 @@ class Encoder(nn.Module):
else: else:
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state) f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state) f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid
return f_actions, f_state, f_g_g_card, a_mask, valid
class Actor(nn.Module): class Actor(nn.Module):
...@@ -473,7 +495,29 @@ class Critic(nn.Module): ...@@ -473,7 +495,29 @@ class Critic(nn.Module):
return x return x
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main): class GlobalCritic(nn.Module):
channels: Sequence[int] = (128, 128)
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, rstate1, rstate2, g_cards):
    """Compute a scalar value estimate from two recurrent states and ``g_cards``.

    Args:
        rstate1: First recurrent state. It is indexed as a pair, and both
            ``rstate1[0]`` and ``rstate1[1]`` are used.
        rstate2: Second recurrent state, indexed the same way as ``rstate1``.
        g_cards: Feature array that is linearly projected into a per-channel
            scale and shift applied to the hidden representation.

    Returns:
        The output of the final ``nn.Dense(1)`` layer (trailing dimension 1),
        computed with ``dtype=jnp.float32``.
    """
    # Use both components of each recurrent state. The second state used to
    # contribute ``rstate2[0]`` twice, so ``rstate2[1]`` never reached the
    # critic. NOTE(review): assumes rstate2[1] has the same trailing size as
    # rstate2[0], so parameter shapes are unchanged -- confirm.
    f_state = jnp.concatenate(
        [rstate1[0], rstate1[1], rstate2[0], rstate2[1]], axis=-1)
    mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
    x = mlp(self.channels, last_lin=True)(f_state)
    c = self.channels[-1]
    # Project g_cards to 2 * c channels, then split into a multiplicative
    # scale ``s`` and an additive shift ``b`` for the state features.
    t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
    s, b = jnp.split(t, 2, axis=-1)
    x = x * s + b
    x = mlp([c], last_lin=False)(x)
    # The value head is always evaluated in float32, regardless of self.dtype.
    x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
    return x
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False):
if main is not None: if main is not None:
rstate1, rstate2 = rstate rstate1, rstate2 = rstate
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2) rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
...@@ -484,10 +528,13 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main): ...@@ -484,10 +528,13 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
rstate = rstate1, rstate2 rstate = rstate1, rstate2
if done is not None: if done is not None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate) rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
if return_state:
return rstate, (f_state, rstate)
else:
return rstate, f_state return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True): def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True, return_state=False):
if switch: if switch:
def body_fn(cell, carry, x, done, switch): def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry rstate, init_rstate2 = carry
...@@ -497,7 +544,7 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True ...@@ -497,7 +544,7 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
return (rstate, init_rstate2), y return (rstate, init_rstate2), y
else: else:
def body_fn(cell, carry, x, done, main): def body_fn(cell, carry, x, done, main):
return rnn_step_by_main(cell, carry, x, done, main) return rnn_step_by_main(cell, carry, x, done, main, return_state)
scan = nn.scan( scan = nn.scan(
body_fn, variable_broadcast='params', body_fn, variable_broadcast='params',
split_rngs={'params': False}) split_rngs={'params': False})
...@@ -531,6 +578,8 @@ class ModelArgs(EncoderArgs): ...@@ -531,6 +578,8 @@ class ModelArgs(EncoderArgs):
"""the type of RNN to use, None for no RNN""" """the type of RNN to use, None for no RNN"""
film: bool = False film: bool = False
"""whether to use FiLM for the actor""" """whether to use FiLM for the actor"""
oppo_info: bool = False
"""whether to use opponent's information"""
rwkv_head_size: int = 32 rwkv_head_size: int = 32
"""the head size for the RWKV""" """the head size for the RWKV"""
...@@ -546,6 +595,7 @@ class RNNAgent(nn.Module): ...@@ -546,6 +595,7 @@ class RNNAgent(nn.Module):
noam: bool = False noam: bool = False
rwkv_head_size: int = 32 rwkv_head_size: int = 32
action_feats: bool = True action_feats: bool = True
oppo_info: bool = False
version: int = 0 version: int = 0
switch: bool = True switch: bool = True
...@@ -557,6 +607,8 @@ class RNNAgent(nn.Module): ...@@ -557,6 +607,8 @@ class RNNAgent(nn.Module):
@nn.compact @nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None): def __call__(self, x, rstate, done=None, switch_or_main=None):
batch_size = jax.tree.leaves(rstate)[0].shape[0]
c = self.num_channels c = self.num_channels
oc = self.rnn_channels if self.rnn_type == 'rwkv' else None oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
encoder = Encoder( encoder = Encoder(
...@@ -571,10 +623,11 @@ class RNNAgent(nn.Module): ...@@ -571,10 +623,11 @@ class RNNAgent(nn.Module):
card_mask=self.card_mask, card_mask=self.card_mask,
noam=self.noam, noam=self.noam,
action_feats=self.action_feats, action_feats=self.action_feats,
oppo_info=self.oppo_info,
version=self.version, version=self.version,
) )
f_actions, f_state, mask, valid = encoder(x) f_actions, f_state, f_g, mask, valid = encoder(x)
if self.rnn_type in ['lstm', 'none']: if self.rnn_type in ['lstm', 'none']:
rnn_layer = nn.OptimizedLSTMCell( rnn_layer = nn.OptimizedLSTMCell(
...@@ -594,7 +647,6 @@ class RNNAgent(nn.Module): ...@@ -594,7 +647,6 @@ class RNNAgent(nn.Module):
elif self.rnn_type == 'none': elif self.rnn_type == 'none':
f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1) f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
else: else:
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = f_state.shape[0] // batch_size num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1 multi_step = num_steps > 1
...@@ -607,7 +659,11 @@ class RNNAgent(nn.Module): ...@@ -607,7 +659,11 @@ class RNNAgent(nn.Module):
f_state_r, done, switch_or_main = jax.tree.map( f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main)) lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state_r = rnn_forward_2p( rstate, f_state_r = rnn_forward_2p(
rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch) rnn_layer, rstate, f_state_r, done, switch_or_main, self.switch, return_state=self.oppo_info)
if self.oppo_info:
f_state_r, all_rstate = f_state_r
all_rstate = jax.tree.map(
lambda x: jnp.reshape(x, (-1, x.shape[-1])), all_rstate)
f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1])) f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
else: else:
rstate, f_state_r = rnn_step_by_main( rstate, f_state_r = rnn_step_by_main(
...@@ -619,11 +675,29 @@ class RNNAgent(nn.Module): ...@@ -619,11 +675,29 @@ class RNNAgent(nn.Module):
else: else:
actor = Actor( actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype) channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
logits = actor(f_state_r, f_actions, mask)
if self.oppo_info:
critic = GlobalCritic(
channels=[c, c], dtype=self.dtype, param_dtype=self.param_dtype)
if not multi_step:
if isinstance(rstate[0], tuple):
rstate1_t, rstate2_t = rstate
else:
rstate1_t = rstate2_t = rstate
else:
main = switch_or_main.reshape(-1)[:, None]
rstate1, rstate2 = all_rstate
rstate1_t = jax.tree.map(
lambda x1, x2: jnp.where(main, x1, x2), rstate1, rstate2)
rstate2_t = jax.tree.map(
lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
value = critic(rstate1_t, rstate2_t, f_g)
else:
critic = Critic( critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype) channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state_r, f_actions, mask)
value = critic(f_state_r) value = critic(f_state_r)
if self.int_head: if self.int_head:
critic_int = Critic( critic_int = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype) channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
...@@ -696,7 +770,7 @@ class RNDModel(nn.Module): ...@@ -696,7 +770,7 @@ class RNDModel(nn.Module):
version=self.version, version=self.version,
) )
f_actions, f_state, mask, valid = encoder(x) f_state = encoder(x)[1]
c = f_state.shape[-1] c = f_state.shape[-1]
if self.is_predictor: if self.is_predictor:
predictor = MLP([oc, oc], dtype=self.dtype, param_dtype=self.param_dtype) predictor = MLP([oc, oc], dtype=self.dtype, param_dtype=self.param_dtype)
......
...@@ -1527,7 +1527,8 @@ public: ...@@ -1527,7 +1527,8 @@ public:
"verbose"_.Bind(false), "max_options"_.Bind(16), "verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16), "max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(false), "record"_.Bind(false), "async_reset"_.Bind(false),
"greedy_reward"_.Bind(true), "timeout"_.Bind(600)); "greedy_reward"_.Bind(true), "timeout"_.Bind(600),
"oppo_info"_.Bind(false));
} }
template <typename Config> template <typename Config>
static decltype(auto) StateSpec(const Config &conf) { static decltype(auto) StateSpec(const Config &conf) {
...@@ -1539,6 +1540,7 @@ public: ...@@ -1539,6 +1540,7 @@ public:
Spec<uint8_t>({conf["max_options"_], n_action_feats})), Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind( "obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})), Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"obs:g_cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})), "info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})), "info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})), "info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
...@@ -2259,8 +2261,13 @@ public: ...@@ -2259,8 +2261,13 @@ public:
} }
if (play_mode_ == kSelfPlay) { if (play_mode_ == kSelfPlay) {
// if (spec_.config["oppo_info"_]) {
if (false) {
reward = winner_ == 0 ? base_reward : -base_reward;
} else {
// to_play_ is the previous player // to_play_ is the previous player
reward = winner_ == player ? base_reward : -base_reward; reward = winner_ == player ? base_reward : -base_reward;
}
} else { } else {
reward = winner_ == ai_player_ ? base_reward : -base_reward; reward = winner_ == ai_player_ ? base_reward : -base_reward;
} }
...@@ -2331,6 +2338,9 @@ public: ...@@ -2331,6 +2338,9 @@ public:
} }
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_); auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
if (spec_.config["oppo_info"_]) {
_set_obs_g_cards(state["obs:g_cards_"_]);
}
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards); _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
...@@ -2438,8 +2448,30 @@ private: ...@@ -2438,8 +2448,30 @@ private:
return {spec_infos, loc_n_cards}; return {spec_infos, loc_n_cards};
} }
// Fills f_cards with one row per card for both players, walking player 0
// first and then player 1, and for each player the locations in the fixed
// order listed below. Rows are written consecutively starting at row 0.
//
// NOTE(review): the running row index is not checked against the number of
// rows in f_cards here -- confirm the buffer is always large enough.
void _set_obs_g_cards(TArray<uint8_t> &f_cards) {
  // Locations visited for every player, in output order.
  const std::vector<uint8_t> locations = {
      LOCATION_DECK,  LOCATION_HAND,    LOCATION_MZONE, LOCATION_SZONE,
      LOCATION_GRAVE, LOCATION_REMOVED, LOCATION_EXTRA,
  };

  int row = 0;
  for (int player = 0; player < 2; ++player) {
    for (const uint8_t location : locations) {
      const std::vector<Card> cards = get_cards_in_location(player, location);
      for (const auto &card : cards) {
        const CardId card_id = c_get_card_id(card.code_);
        // hide = false: every card is written with its real card id.
        // global = false: NOTE(review) the last argument selects how
        // _set_obs_card_ encodes the owner; passing false keeps the
        // to_play_-relative encoding -- confirm this is intended for the
        // global view.
        _set_obs_card_(f_cards, row, card, false, card_id, false);
        ++row;
      }
    }
  }
}
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c, void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide, CardId card_id = 0) { bool hide, CardId card_id = 0, bool global = false) {
// check offset exceeds max_cards // check offset exceeds max_cards
uint8_t location = c.location_; uint8_t location = c.location_;
bool overlay = location & LOCATION_OVERLAY; bool overlay = location & LOCATION_OVERLAY;
...@@ -2462,7 +2494,7 @@ private: ...@@ -2462,7 +2494,7 @@ private:
seq = c.sequence_ + 1; seq = c.sequence_ + 1;
} }
f_cards(offset, 3) = seq; f_cards(offset, 3) = seq;
f_cards(offset, 4) = (c.controler_ != to_play_) ? 1 : 0; f_cards(offset, 4) = global ? c.controler_ : ((c.controler_ != to_play_) ? 1 : 0);
if (overlay) { if (overlay) {
f_cards(offset, 5) = position_to_id(POS_FACEUP); f_cards(offset, 5) = position_to_id(POS_FACEUP);
f_cards(offset, 6) = 1; f_cards(offset, 6) = 1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment