Commit e6dc7744 authored by sbl1996@126.com's avatar sbl1996@126.com

Fix eval

parent afbe4893
......@@ -336,7 +336,7 @@ def rollout(
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params, params_rt = params_queue.get()
params_rt, eval_params = params_queue.get()
else:
params_rt, = params_queue.get()
......
......@@ -37,10 +37,11 @@ class RunningMeanStd(struct.PyTreeNode):
@classmethod
def create(cls, shape=()):
# TODO: use numpy and float64
return cls(
mean=jnp.zeros(shape, "float64"),
var=jnp.ones(shape, "float64"),
count=jnp.full(shape, 1e-4, "float64"),
mean=jnp.zeros(shape, "float32"),
var=jnp.ones(shape, "float32"),
count=jnp.full(shape, 1e-4, "float32"),
)
def update(self, x):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment