Commit cd59a6e9 authored by sbl1996@126.com's avatar sbl1996@126.com

Fix

parent 4ae87a70
......@@ -9,12 +9,12 @@ def masked_mean(x, valid):
return x.sum() / valid.sum()
def masked_normalize(x, valid, epsilon=1e-8):
def masked_normalize(x, valid, eps=1e-8):
x = jnp.where(valid, x, jnp.zeros_like(x))
n = valid.sum()
mean = x.sum() / n
variance = jnp.square(x - mean).sum() / n
return (x - mean) / jnp.sqrt(variance + epsilon)
return (x - mean) / jnp.sqrt(variance + eps)
def categorical_sample(logits, key):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment