Commit e6dc7744 authored by sbl1996@126.com's avatar sbl1996@126.com

Fix eval

parent afbe4893
...@@ -336,7 +336,7 @@ def rollout( ...@@ -336,7 +336,7 @@ def rollout(
): ):
eval_mode = 'self' if args.eval_checkpoint else 'bot' eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot': if eval_mode != 'bot':
eval_params, params_rt = params_queue.get() params_rt, eval_params = params_queue.get()
else: else:
params_rt, = params_queue.get() params_rt, = params_queue.get()
......
...@@ -37,10 +37,11 @@ class RunningMeanStd(struct.PyTreeNode): ...@@ -37,10 +37,11 @@ class RunningMeanStd(struct.PyTreeNode):
@classmethod @classmethod
def create(cls, shape=()): def create(cls, shape=()):
# TODO: use numpy and float64
return cls( return cls(
mean=jnp.zeros(shape, "float64"), mean=jnp.zeros(shape, "float32"),
var=jnp.ones(shape, "float64"), var=jnp.ones(shape, "float32"),
count=jnp.full(shape, 1e-4, "float64"), count=jnp.full(shape, 1e-4, "float32"),
) )
def update(self, x): def update(self, x):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment