Commit e7d409ec authored by sbl1996@126.com's avatar sbl1996@126.com

Add agent version 2

parent f94a7fc3
...@@ -406,7 +406,12 @@ class Encoder(nn.Module): ...@@ -406,7 +406,12 @@ class Encoder(nn.Module):
f_state = jnp.concatenate(g_feats, axis=-1) f_state = jnp.concatenate(g_feats, axis=-1)
oc = self.out_channels or c oc = self.out_channels or c
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state) if self.version == 2:
f_state = GLUMlp(
intermediate_size=c * 2, output_size=oc,
dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
else:
f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state) f_state = layer_norm(dtype=self.dtype)(f_state)
return f_actions, f_state, a_mask, valid return f_actions, f_state, a_mask, valid
......
...@@ -57,6 +57,7 @@ class MLP(nn.Module): ...@@ -57,6 +57,7 @@ class MLP(nn.Module):
class GLUMlp(nn.Module): class GLUMlp(nn.Module):
intermediate_size: int intermediate_size: int
output_size: Optional[int] = None
dtype: Optional[jnp.dtype] = None dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal() kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
...@@ -74,8 +75,7 @@ class GLUMlp(nn.Module): ...@@ -74,8 +75,7 @@ class GLUMlp(nn.Module):
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
) for _ in range(3) ) for _ in range(3)
] ]
output_size = self.output_size or inputs.shape[-1]
actual_out_dim = inputs.shape[-1]
g = dense[0]( g = dense[0](
features=self.intermediate_size, features=self.intermediate_size,
name="gate", name="gate",
...@@ -86,7 +86,7 @@ class GLUMlp(nn.Module): ...@@ -86,7 +86,7 @@ class GLUMlp(nn.Module):
name="up", name="up",
)(inputs) )(inputs)
x = dense[2]( x = dense[2](
features=actual_out_dim, features=output_size,
name="down", name="down",
)(x) )(x)
return x return x
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment