Commit b0d45e40 authored by sbl1996@126.com

More bfloat16

parent a1329e4f
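Note: this commit threads the module-level `dtype` into more submodules (the `nn.LayerNorm`, `nn.Embed`, and `nn.Dense` partials), so activations are computed in bfloat16 while the parameters stay in `param_dtype` (float32). A minimal sketch of that Flax convention, assuming bfloat16 as the compute dtype; `Block` is illustrative and not code from this repository:

    # Sketch of Flax's dtype / param_dtype split: `dtype` is the computation
    # dtype of activations, `param_dtype` the storage dtype of the weights.
    from typing import Any
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Block(nn.Module):
        dtype: Any = jnp.bfloat16        # activations / matmuls
        param_dtype: Any = jnp.float32   # master weights

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(64, dtype=self.dtype, param_dtype=self.param_dtype)(x)
            x = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype)(x)
            return x

    x = jnp.ones((2, 16), jnp.float32)
    params = Block().init(jax.random.PRNGKey(0), x)
    y = Block().apply(params, x)
    # kernels and scales in `params` remain float32; y.dtype is bfloat16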
@@ -78,6 +78,7 @@ class EnvPreprocess(gym.Wrapper):
     def __init__(self, env, skip_mask):
         super().__init__(env)
+        self.num_envs = env.num_envs
         self.skip_mask = skip_mask

     def reset(self, **kwargs):
...
@@ -90,7 +90,7 @@ class CardEncoder(nn.Module):
         assert self.version > 0
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
-        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
         embed = partial(
             nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
         fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
@@ -100,7 +100,7 @@ class CardEncoder(nn.Module):
         num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))

         x1 = x[:, :, :10].astype(jnp.int32)
-        x2 = x[:, :, 10:].astype(jnp.float32)
+        x2 = x[:, :, 10:].astype(self.dtype)

         x_loc = x1[:, :, 0]
         x_seq = x1[:, :, 1]
@@ -158,12 +158,16 @@ class CardEncoder(nn.Module):
         x_cards = jnp.concatenate(feats[1:], axis=-1)
         x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
         x_cards = x_cards * feats[0]
+        print("before", x_cards.dtype)
         f_cards = layer_norm()(x_cards)
+        # f_cards = f_cards.astype(self.dtype)
+        print("norm", f_cards.dtype)
         if self.oppo_info:
             x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
             x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
             x_cards_g = x_cards_g * feats_g[0]
             f_cards_g = layer_norm()(x_cards_g)
+            # f_cards_g = f_cards_g.astype(self.dtype)
         else:
             f_cards_g = None
         return f_cards_g, f_cards, c_mask
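Note: the `print("before", ...)` / `print("norm", ...)` calls and commented-out casts added here read as temporary dtype debugging. Dtypes can also be checked without runtime prints via `jax.eval_shape`; the sketch below uses a stand-in `nn.LayerNorm`, not the repository's `CardEncoder`:

    # Self-contained sketch: inspect dtype flow abstractly, without executing FLOPs.
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    layer = nn.LayerNorm(dtype=jnp.bfloat16)
    x = jnp.ones((4, 128), jnp.bfloat16)
    params = layer.init(jax.random.PRNGKey(0), x)
    out = jax.eval_shape(layer.apply, params, x)  # ShapeDtypeStruct only
    print(out.shape, out.dtype)                   # (4, 128) bfloat16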
@@ -180,7 +184,7 @@ class GlobalEncoder(nn.Module):
         batch_size = x.shape[0]
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
-        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
         embed = partial(
             nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
         fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
@@ -192,7 +196,7 @@ class GlobalEncoder(nn.Module):
         bin_points, bin_intervals = make_bin_params(n_bins=32)
         num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))

-        x1 = x[:, :4].astype(jnp.float32)
+        x1 = x[:, :4].astype(self.dtype)
         x2 = x[:, 4:8].astype(jnp.int32)
         x3 = x[:, 8:22].astype(jnp.int32)
@@ -241,18 +245,18 @@ class Encoder(nn.Module):
         n_embed, embed_dim = self.embedding_shape
         n_embed = 1 + n_embed  # 1 (index 0) for unknown
-        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
         embed = partial(
-            nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
-        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
+            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
+        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype, dtype=self.dtype)

         id_embed = embed(n_embed, embed_dim)
         card_encoder = CardEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
+            channels=c, dtype=self.dtype, param_dtype=self.param_dtype,
             version=self.version, oppo_info=self.oppo_info)
         ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
         action_encoder = ActionEncoderCls(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+            channels=c, dtype=self.dtype, param_dtype=self.param_dtype)

         x_cards = x['cards_']
         x_global = x['global_']
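Note: with `dtype=self.dtype` now baked into the `fc_layer` partial, the per-call `dtype=` arguments are dropped in the hunks below. A small sketch of the `functools.partial` pattern, with illustrative dtypes:

    # The compute dtype is set once in the partial; call sites no longer pass it,
    # but a keyword default can still be overridden per layer if ever needed.
    from functools import partial
    import jax.numpy as jnp
    import flax.linen as nn

    fc_layer = partial(nn.Dense, use_bias=False, param_dtype=jnp.float32, dtype=jnp.bfloat16)
    proj = fc_layer(256)                          # nn.Dense(256, use_bias=False, dtype=bfloat16, ...)
    proj_fp32 = fc_layer(256, dtype=jnp.float32)  # explicit override still possible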
@@ -288,27 +292,26 @@ class Encoder(nn.Module):
         c_mask = None
         num_heads = max(2, c // 128)
-        for _ in range(self.num_layers):
+        for i in range(self.num_layers):
             f_cards = get_encoder_layer_cls(
                 self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                 f_cards, src_key_padding_mask=c_mask)
-        f_cards = layer_norm(dtype=self.dtype)(f_cards)
+        f_cards = layer_norm()(f_cards)
         f_g_card = f_cards[:, 0]
         fs_g_card.append(f_g_card)
         f_g_g_card, f_g_card = fs_g_card

         # Global
         x_global = GlobalEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
-        x_global = x_global.astype(self.dtype)
+            channels=c, dtype=self.dtype, param_dtype=self.param_dtype, version=self.version)(x_global)
         if self.version == 2:
-            x_global = fc_layer(c, dtype=jnp.float32)(x_global)
+            x_global = fc_layer(c)(x_global)
             f_global = x_global + GLUMlp(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(
-                layer_norm(dtype=self.dtype)(x_global))
+                layer_norm()(x_global))
         else:
             f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
-        f_global = fc_layer(c, dtype=self.dtype)(f_global)
-        f_global = layer_norm(dtype=self.dtype)(f_global)
+        f_global = fc_layer(c)(f_global)
+        f_global = layer_norm()(f_global)

         # History actions
         x_h_actions = x_h_actions.astype(jnp.int32)
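Note: the explicit `x_global = x_global.astype(self.dtype)` line is removed because `GlobalEncoder` is now constructed with `dtype=self.dtype` and already returns that dtype, making the cast a no-op. A minimal sketch, assuming bfloat16 as the compute dtype:

    # A Flax layer whose compute dtype is bfloat16 already emits bfloat16 outputs,
    # so an extra .astype(bfloat16) on its result changes nothing.
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    dense = nn.Dense(8, dtype=jnp.bfloat16, param_dtype=jnp.float32)
    x = jnp.ones((2, 4), jnp.float32)
    params = dense.init(jax.random.PRNGKey(0), x)
    y = dense.apply(params, x)
    assert y.dtype == jnp.bfloat16  # already the compute dtype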
@@ -321,7 +324,7 @@ class Encoder(nn.Module):
             if self.freeze_id:
                 x_h_id = jax.lax.stop_gradient(x_h_id)
             x_h_id = MLP(
-                (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
+                (c, c), dtype=self.dtype, param_dtype=self.param_dtype,
                 kernel_init=default_fc_init2)(x_h_id)

             x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
@@ -331,13 +334,13 @@ class Encoder(nn.Module):
             x_h_a_feats = jnp.concatenate([
                 *x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
-            f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
+            f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c)(x_h_a_feats))
             f_h_actions = PositionalEncoding()(f_h_actions)

             for _ in range(self.num_layers):
                 f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                     f_h_actions, src_key_padding_mask=h_mask)
-            f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
+            f_g_h_actions = layer_norm()(f_h_actions[:, 0])
         else:
             h_mask = x_h_actions[:, :, 3] == 0  # msg == 0
             h_mask = h_mask.at[:, 0].set(False)
@@ -347,7 +350,7 @@ class Encoder(nn.Module):
             if self.freeze_id:
                 x_h_id = jax.lax.stop_gradient(x_h_id)
-            x_h_id = fc_layer(c, dtype=jnp.float32)(x_h_id)
+            x_h_id = fc_layer(c)(x_h_id)

             x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
             x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
@@ -355,7 +358,7 @@ class Encoder(nn.Module):
             x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
             x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
             x_h_a_feats = layer_norm()(x_h_a_feats)
-            x_h_a_feats = fc_layer(c, dtype=self.dtype)(x_h_a_feats)
+            x_h_a_feats = fc_layer(c)(x_h_a_feats)

             if self.noam:
                 f_h_actions = LlamaEncoderLayer(
@@ -365,7 +368,7 @@ class Encoder(nn.Module):
                 x_h_a_feats = PositionalEncoding()(x_h_a_feats)
                 f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                     x_h_a_feats, src_key_padding_mask=h_mask)
-            f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
+            f_g_h_actions = layer_norm()(f_h_actions[:, 0])

         # Actions
@@ -382,12 +385,12 @@ class Encoder(nn.Module):
             spec_index = decode_id(x_actions[..., :2])
             B = jnp.arange(batch_size)
             f_a_cards = f_cards[B[:, None], spec_index]
-            f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
+            f_a_cards = fc_layer(c)(f_a_cards)

             x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
-            x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
+            x_a_feats = fc_layer(c)(x_a_feats)
             f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
-            f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
+            f_actions = fc_layer(c)(nn.leaky_relu(f_actions, negative_slope=0.1))
             f_actions = layer_norm(dtype=self.dtype)(f_actions)

             a_mask = x_actions[:, :, 2] == 0
@@ -408,16 +411,16 @@ class Encoder(nn.Module):
             x_a_id = id_embed(x_a_id)
             if self.freeze_id:
                 x_a_id = jax.lax.stop_gradient(x_a_id)
-            x_a_id = fc_layer(c, dtype=jnp.float32)(x_a_id)
+            x_a_id = fc_layer(c)(x_a_id)

             x_a_feats = action_encoder(x_actions[..., 3:])
             x_a_feats.append(x_a_id)
             x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
             x_a_feats = layer_norm()(x_a_feats)
-            x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
-            f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
+            x_a_feats = fc_layer(c)(x_a_feats)
+            f_a_cards = fc_layer(c)(f_a_cards)
             f_actions = jax.nn.silu(f_a_cards) * x_a_feats
-            f_actions = fc_layer(c, dtype=self.dtype)(f_actions)
+            f_actions = fc_layer(c)(f_actions)
             f_actions = x_a_feats + f_actions

             a_mask = x_actions[:, :, 3] == 0
@@ -428,11 +431,12 @@ class Encoder(nn.Module):
             g_feats.append(f_g_h_actions)

         if self.action_feats:
-            f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
+            f_actions_g = fc_layer(c)(f_actions)
             a_mask_ = (1 - a_mask.astype(f_actions.dtype))
             f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
             f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
             g_feats.append(f_g_actions)
+            print("f_g_actions", f_g_actions.dtype)

         f_state = jnp.concatenate(g_feats, axis=-1)
         oc = self.out_channels or c
@@ -442,7 +446,8 @@ class Encoder(nn.Module):
                 dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         else:
             f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
-        f_state = layer_norm(dtype=self.dtype)(f_state)
+        print("f_state", f_state.dtype)
+        f_state = layer_norm()(f_state)
         return f_actions, f_state, f_g_g_card, a_mask, valid
@@ -732,7 +737,7 @@ class RNNAgent(nn.Module):
         CriticCls = CrossCritic if self.batch_norm else Critic
         cs = [self.critic_width] * self.critic_depth
         critic = CriticCls(
-            channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+            channels=cs, dtype=jnp.float32, param_dtype=self.param_dtype)

         if self.oppo_info:
             if not multi_step:
                 if isinstance(rstate[0], tuple):
@@ -754,7 +759,7 @@ class RNNAgent(nn.Module):
         if self.int_head:
             cs = [self.critic_width] * self.critic_depth
             critic_int = Critic(
-                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+                channels=cs, dtype=jnp.float32, param_dtype=self.param_dtype)
             value_int = critic_int(f_state_r)
             value = (value, value_int)
         return rstate, logits, value, valid
...
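Note: while the encoder moves to the module dtype (bfloat16), the critic heads are pinned to `dtype=jnp.float32`, presumably to keep value estimates in full precision. A hedged, self-contained sketch of that split; `TinyAgent` is illustrative, not the repository's `RNNAgent`:

    # bfloat16 trunk with a float32 value head, mirroring the dtype choices above.
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyAgent(nn.Module):
        @nn.compact
        def __call__(self, x):
            h = nn.Dense(256, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)  # bf16 trunk
            h = nn.relu(h)
            v = nn.Dense(1, dtype=jnp.float32, param_dtype=jnp.float32)(h)     # fp32 value head
            return v

    x = jnp.ones((4, 64), jnp.float32)
    params = TinyAgent().init(jax.random.PRNGKey(0), x)
    v = TinyAgent().apply(params, x)
    assert v.dtype == jnp.float32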