Commit d43e5903 authored by sbl1996@126.com's avatar sbl1996@126.com

version 2 no leaky relu

parent c90562de
...@@ -271,7 +271,6 @@ class Encoder(nn.Module): ...@@ -271,7 +271,6 @@ class Encoder(nn.Module):
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global) channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
x_global = x_global.astype(self.dtype) x_global = x_global.astype(self.dtype)
if self.version == 2: if self.version == 2:
x_global = jax.nn.leaky_relu(x_global, negative_slope=0.1)
x_global = fc_layer(c, dtype=jnp.float32)(x_global) x_global = fc_layer(c, dtype=jnp.float32)(x_global)
f_global = x_global + GLUMlp(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)( f_global = x_global + GLUMlp(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(
layer_norm(dtype=self.dtype)(x_global)) layer_norm(dtype=self.dtype)(x_global))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment