Commit 112f1bc6 authored by sbl1996@126.com's avatar sbl1996@126.com

Use c*2 for version 2

parent e7d409ec
......@@ -273,7 +273,7 @@ class Encoder(nn.Module):
if self.version == 2:
x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
x_global = fc_layer(c, dtype=jnp.float32)(x_global)
f_global = x_global + GLUMlp(c, dtype=self.dtype, param_dtype=self.param_dtype)(
f_global = x_global + GLUMlp(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(
layer_norm(dtype=self.dtype)(x_global))
else:
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment