Commit bcf43ac9 authored by sbl1996@126.com's avatar sbl1996@126.com

Remove print

parent b0d45e40
......@@ -158,10 +158,8 @@ class CardEncoder(nn.Module):
x_cards = jnp.concatenate(feats[1:], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * feats[0]
print("before", x_cards.dtype)
f_cards = layer_norm()(x_cards)
# f_cards = f_cards.astype(self.dtype)
print("norm", f_cards.dtype)
if self.oppo_info:
x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
......@@ -436,7 +434,6 @@ class Encoder(nn.Module):
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
g_feats.append(f_g_actions)
print("f_g_actions", f_g_actions.dtype)
f_state = jnp.concatenate(g_feats, axis=-1)
oc = self.out_channels or c
......@@ -446,7 +443,6 @@ class Encoder(nn.Module):
dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
else:
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
print("f_state", f_state.dtype)
f_state = layer_norm()(f_state)
return f_actions, f_state, f_g_g_card, a_mask, valid
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment