Commit b0d45e40 authored by sbl1996@126.com

More bfloat16

parent a1329e4f
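This commit keeps pushing the encoder's computation to bfloat16 while leaving parameter storage in float32. In Flax, a module's `dtype` sets the computation/output dtype and `param_dtype` sets how its parameters are stored, so the recurring edit below is to thread `dtype=self.dtype` through the `partial`s and drop the per-call overrides. A minimal self-contained sketch of that pattern (the `Block` module here is illustrative, not code from this repo):

```python
from functools import partial
from typing import Any

import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    dtype: Any = jnp.bfloat16       # computation/output dtype
    param_dtype: Any = jnp.float32  # storage dtype of the parameters

    @nn.compact
    def __call__(self, x):
        # Without dtype=..., LayerNorm would infer float32 from its
        # float32 scale/bias and silently promote bfloat16 activations.
        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
        fc = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
        x = fc(128)(x)
        x = layer_norm()(x)
        return x
```

With `dtype=jnp.bfloat16`, the inputs and the float32 parameters are cast to bfloat16 for the matmul and normalization, so the output stays bfloat16 instead of being promoted back to float32.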
@@ -78,6 +78,7 @@ class EnvPreprocess(gym.Wrapper):
def __init__(self, env, skip_mask):
super().__init__(env)
self.num_envs = env.num_envs
+self.skip_mask = skip_mask
def reset(self, **kwargs):
@@ -90,7 +90,7 @@ class CardEncoder(nn.Module):
assert self.version > 0
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
-layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
@@ -100,7 +100,7 @@ class CardEncoder(nn.Module):
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :, :10].astype(jnp.int32)
-x2 = x[:, :, 10:].astype(jnp.float32)
+x2 = x[:, :, 10:].astype(self.dtype)
x_loc = x1[:, :, 0]
x_seq = x1[:, :, 1]
@@ -158,12 +158,16 @@ class CardEncoder(nn.Module):
x_cards = jnp.concatenate(feats[1:], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * feats[0]
+print("before", x_cards.dtype)
f_cards = layer_norm()(x_cards)
+# f_cards = f_cards.astype(self.dtype)
+print("norm", f_cards.dtype)
if self.oppo_info:
x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
x_cards_g = x_cards_g * feats_g[0]
f_cards_g = layer_norm()(x_cards_g)
+# f_cards_g = f_cards_g.astype(self.dtype)
else:
f_cards_g = None
return f_cards_g, f_cards, c_mask
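The `print` statements added above look like temporary checks of the activation dtypes: with the default `dtype=None`, `nn.LayerNorm` infers its computation dtype from the input and its float32 scale/bias, which promotes a bfloat16 input back to float32. Passing `dtype=self.dtype` to the `layer_norm` partial keeps the output in bfloat16; the commented-out `astype` calls note the alternative of casting by hand. A standalone check of that behavior (illustrative, not part of the diff):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 8), dtype=jnp.bfloat16)

default_ln = nn.LayerNorm()                 # dtype=None: infer from input and params
bf16_ln = nn.LayerNorm(dtype=jnp.bfloat16)  # force bfloat16 computation/output

y1, _ = default_ln.init_with_output(jax.random.PRNGKey(0), x)
y2, _ = bf16_ln.init_with_output(jax.random.PRNGKey(0), x)
print(y1.dtype)  # float32: promoted by the float32 scale/bias
print(y2.dtype)  # bfloat16
```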
@@ -180,7 +184,7 @@ class GlobalEncoder(nn.Module):
batch_size = x.shape[0]
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
-layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
@@ -192,7 +196,7 @@ class GlobalEncoder(nn.Module):
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
-x1 = x[:, :4].astype(jnp.float32)
+x1 = x[:, :4].astype(self.dtype)
x2 = x[:, 4:8].astype(jnp.int32)
x3 = x[:, 8:22].astype(jnp.int32)
@@ -241,18 +245,18 @@ class Encoder(nn.Module):
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
-layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
embed = partial(
-nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
-fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
+nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
+fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype, dtype=self.dtype)
id_embed = embed(n_embed, embed_dim)
card_encoder = CardEncoder(
-channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
+channels=c, dtype=self.dtype, param_dtype=self.param_dtype,
version=self.version, oppo_info=self.oppo_info)
ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
action_encoder = ActionEncoderCls(
-channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+channels=c, dtype=self.dtype, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
@@ -288,27 +292,26 @@ class Encoder(nn.Module):
c_mask = None
num_heads = max(2, c // 128)
-for _ in range(self.num_layers):
+for i in range(self.num_layers):
f_cards = get_encoder_layer_cls(
self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_cards, src_key_padding_mask=c_mask)
-f_cards = layer_norm(dtype=self.dtype)(f_cards)
+f_cards = layer_norm()(f_cards)
f_g_card = f_cards[:, 0]
fs_g_card.append(f_g_card)
f_g_g_card, f_g_card = fs_g_card
# Global
x_global = GlobalEncoder(
-channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
-x_global = x_global.astype(self.dtype)
+channels=c, dtype=self.dtype, param_dtype=self.param_dtype, version=self.version)(x_global)
if self.version == 2:
-x_global = fc_layer(c, dtype=jnp.float32)(x_global)
+x_global = fc_layer(c)(x_global)
f_global = x_global + GLUMlp(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(
-layer_norm(dtype=self.dtype)(x_global))
+layer_norm()(x_global))
else:
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
-f_global = fc_layer(c, dtype=self.dtype)(f_global)
-f_global = layer_norm(dtype=self.dtype)(f_global)
+f_global = fc_layer(c)(f_global)
+f_global = layer_norm()(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
@@ -321,7 +324,7 @@ class Encoder(nn.Module):
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
x_h_id = MLP(
-(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
+(c, c), dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_h_id)
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
@@ -331,13 +334,13 @@ class Encoder(nn.Module):
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
-f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
+f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
-f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
+f_g_h_actions = layer_norm()(f_h_actions[:, 0])
else:
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
@@ -347,7 +350,7 @@ class Encoder(nn.Module):
if self.freeze_id:
x_h_id = jax.lax.stop_gradient(x_h_id)
-x_h_id = fc_layer(c, dtype=jnp.float32)(x_h_id)
+x_h_id = fc_layer(c)(x_h_id)
x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
@@ -355,7 +358,7 @@ class Encoder(nn.Module):
x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
x_h_a_feats = layer_norm()(x_h_a_feats)
-x_h_a_feats = fc_layer(c, dtype=self.dtype)(x_h_a_feats)
+x_h_a_feats = fc_layer(c)(x_h_a_feats)
if self.noam:
f_h_actions = LlamaEncoderLayer(
@@ -365,7 +368,7 @@ class Encoder(nn.Module):
x_h_a_feats = PositionalEncoding()(x_h_a_feats)
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
x_h_a_feats, src_key_padding_mask=h_mask)
-f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
+f_g_h_actions = layer_norm()(f_h_actions[:, 0])
# Actions
@@ -382,12 +385,12 @@ class Encoder(nn.Module):
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
-f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
+f_a_cards = fc_layer(c)(f_a_cards)
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
-x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
+x_a_feats = fc_layer(c)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
-f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
+f_actions = fc_layer(c)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
@@ -408,16 +411,16 @@ class Encoder(nn.Module):
x_a_id = id_embed(x_a_id)
if self.freeze_id:
x_a_id = jax.lax.stop_gradient(x_a_id)
-x_a_id = fc_layer(c, dtype=jnp.float32)(x_a_id)
+x_a_id = fc_layer(c)(x_a_id)
x_a_feats = action_encoder(x_actions[..., 3:])
x_a_feats.append(x_a_id)
x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
x_a_feats = layer_norm()(x_a_feats)
-x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
-f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
+x_a_feats = fc_layer(c)(x_a_feats)
+f_a_cards = fc_layer(c)(f_a_cards)
f_actions = jax.nn.silu(f_a_cards) * x_a_feats
-f_actions = fc_layer(c, dtype=self.dtype)(f_actions)
+f_actions = fc_layer(c)(f_actions)
f_actions = x_a_feats + f_actions
a_mask = x_actions[:, :, 3] == 0
@@ -428,11 +431,12 @@ class Encoder(nn.Module):
g_feats.append(f_g_h_actions)
if self.action_feats:
-f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
+f_actions_g = fc_layer(c)(f_actions)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
g_feats.append(f_g_actions)
+print("f_g_actions", f_g_actions.dtype)
f_state = jnp.concatenate(g_feats, axis=-1)
oc = self.out_channels or c
@@ -442,7 +446,8 @@ class Encoder(nn.Module):
dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
else:
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
-f_state = layer_norm(dtype=self.dtype)(f_state)
+print("f_state", f_state.dtype)
+f_state = layer_norm()(f_state)
return f_actions, f_state, f_g_g_card, a_mask, valid
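One detail worth noting in the action pooling above: `a_mask` marks padded action slots, so it is inverted and cast to the feature dtype before the masked mean, which means the reduction itself also runs in bfloat16 once `f_actions` does. A standalone sketch of that pooling (the function name is illustrative):

```python
import jax.numpy as jnp


def masked_mean_pool(f_actions, a_mask):
    # f_actions: (batch, num_actions, channels), e.g. bfloat16 features
    # a_mask:    (batch, num_actions) bool, True where the slot is padding
    a_mask_ = 1 - a_mask.astype(f_actions.dtype)            # 1.0 for valid slots
    pooled = (f_actions * a_mask_[:, :, None]).sum(axis=1)  # sum over valid actions
    return pooled / a_mask_.sum(axis=1, keepdims=True)      # mean over valid actions
```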
@@ -732,7 +737,7 @@ class RNNAgent(nn.Module):
CriticCls = CrossCritic if self.batch_norm else Critic
cs = [self.critic_width] * self.critic_depth
critic = CriticCls(
-channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+channels=cs, dtype=jnp.float32, param_dtype=self.param_dtype)
if self.oppo_info:
if not multi_step:
if isinstance(rstate[0], tuple):
@@ -754,7 +759,7 @@ class RNNAgent(nn.Module):
if self.int_head:
cs = [self.critic_width] * self.critic_depth
critic_int = Critic(
-channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+channels=cs, dtype=jnp.float32, param_dtype=self.param_dtype)
value_int = critic_int(f_state_r)
value = (value, value_int)
return rstate, logits, value, valid
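In contrast to the encoder, both critics are now instantiated with `dtype=jnp.float32`, so the value heads compute in full precision even when the backbone runs in bfloat16 (presumably to keep value estimates numerically stable). A rough sketch of that split (the `ToyAgent` module and its layer sizes are illustrative, not the repo's actual `RNNAgent`):

```python
from typing import Any

import jax.numpy as jnp
import flax.linen as nn


class ToyAgent(nn.Module):
    dtype: Any = jnp.bfloat16       # backbone computation dtype
    param_dtype: Any = jnp.float32  # parameter storage dtype

    @nn.compact
    def __call__(self, x):
        # Backbone in bfloat16.
        h = nn.Dense(256, dtype=self.dtype, param_dtype=self.param_dtype)(x)
        h = nn.relu(h)
        # Value head pinned to float32: the bfloat16 features are promoted
        # back to full precision, so the critic output is float32.
        value = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(h)
        return value
```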