Commit e7d409ec authored by sbl1996@126.com's avatar sbl1996@126.com

Add agent version 2

parent f94a7fc3
......@@ -406,7 +406,12 @@ class Encoder(nn.Module):
f_state = jnp.concatenate(g_feats, axis=-1)
oc = self.out_channels or c
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
if self.version == 2:
f_state = GLUMlp(
intermediate_size=c * 2, output_size=oc,
dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
else:
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid
......
......@@ -57,6 +57,7 @@ class MLP(nn.Module):
class GLUMlp(nn.Module):
intermediate_size: int
output_size: Optional[int] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
......@@ -74,8 +75,7 @@ class GLUMlp(nn.Module):
kernel_init=self.kernel_init,
) for _ in range(3)
]
actual_out_dim = inputs.shape[-1]
output_size = self.output_size or inputs.shape[-1]
g = dense[0](
features=self.intermediate_size,
name="gate",
......@@ -86,7 +86,7 @@ class GLUMlp(nn.Module):
name="up",
)(inputs)
x = dense[2](
features=actual_out_dim,
features=output_size,
name="down",
)(x)
return x
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment