Commit cd95bf43 authored by sbl1996@126.com's avatar sbl1996@126.com

Print value in eval

parent 14d32bc7
...@@ -157,17 +157,17 @@ if __name__ == "__main__": ...@@ -157,17 +157,17 @@ if __name__ == "__main__":
params = jax.device_put(params) params = jax.device_put(params)
@jax.jit @jax.jit
def get_probs(params, rstate, obs, done): def get_probs_and_value(params, rstate, obs, done):
agent = create_agent(args) agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2] next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
probs = jax.nn.softmax(logits, axis=-1) probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map( next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate) lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs return next_rstate, probs, value
def predict_fn(rstate, obs, done): def predict_fn(rstate, obs, done):
rstate, probs = get_probs(params, rstate, obs, done) rstate, probs, value = get_probs_and_value(params, rstate, obs, done)
return rstate, np.array(probs) return rstate, np.array(probs), np.array(value)
print(f"loaded checkpoint from {args.checkpoint}") print(f"loaded checkpoint from {args.checkpoint}")
...@@ -194,9 +194,10 @@ if __name__ == "__main__": ...@@ -194,9 +194,10 @@ if __name__ == "__main__":
if args.checkpoint: if args.checkpoint:
_start = time.time() _start = time.time()
rstate, probs = predict_fn(rstate, obs, dones) rstate, probs, value = predict_fn(rstate, obs, dones)
if args.verbose: if args.verbose:
print([f"{p:.4f}" for p in probs[probs != 0].tolist()]) print(f"probs: {[f'{p:.4f}' for p in probs[probs != 0].tolist()]}")
print(f"value: {value[0][0]}")
actions = probs.argmax(axis=1) actions = probs.argmax(axis=1)
model_time += time.time() - _start model_time += time.time() - _start
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment