Commit cd95bf43 authored by sbl1996@126.com

Print value in eval

parent 14d32bc7
......@@ -157,17 +157,17 @@ if __name__ == "__main__":
params = jax.device_put(params)
@jax.jit
def get_probs(params, rstate, obs, done):
def get_probs_and_value(params, rstate, obs, done):
agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs
return next_rstate, probs, value
def predict_fn(rstate, obs, done):
rstate, probs = get_probs(params, rstate, obs, done)
return rstate, np.array(probs)
rstate, probs, value = get_probs_and_value(params, rstate, obs, done)
return rstate, np.array(probs), np.array(value)
print(f"loaded checkpoint from {args.checkpoint}")
......@@ -194,9 +194,10 @@ if __name__ == "__main__":
if args.checkpoint:
_start = time.time()
rstate, probs = predict_fn(rstate, obs, dones)
rstate, probs, value = predict_fn(rstate, obs, dones)
if args.verbose:
print([f"{p:.4f}" for p in probs[probs != 0].tolist()])
print(f"probs: {[f'{p:.4f}' for p in probs[probs != 0].tolist()]}")
print(f"value: {value[0][0]}")
actions = probs.argmax(axis=1)
model_time += time.time() - _start
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment