Commit 671ed3c6 authored by sbl1996@126.com's avatar sbl1996@126.com

Replace tree_map with tree.map

parent 9670ed68
...@@ -207,7 +207,7 @@ if __name__ == "__main__": ...@@ -207,7 +207,7 @@ if __name__ == "__main__":
agent = create_agent(args) agent = create_agent(args)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2) key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs) params = agent.init(agent_key, sample_obs)
print(jax.tree.leaves(params)[0].devices()) print(jax.tree.leaves(params)[0].devices())
with open(args.checkpoint1, "rb") as f: with open(args.checkpoint1, "rb") as f:
......
...@@ -224,7 +224,7 @@ if __name__ == "__main__": ...@@ -224,7 +224,7 @@ if __name__ == "__main__":
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2) key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs) params = agent.init(agent_key, sample_obs)
with open(args.checkpoint, "rb") as f: with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read()) params = flax.serialization.from_bytes(params, f.read())
......
...@@ -153,7 +153,7 @@ if __name__ == "__main__": ...@@ -153,7 +153,7 @@ if __name__ == "__main__":
agent = create_agent(args) agent = create_agent(args)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2) key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels) rstate = init_rnn_state(1, args.rnn_channels)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs)) params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
...@@ -171,7 +171,7 @@ if __name__ == "__main__": ...@@ -171,7 +171,7 @@ if __name__ == "__main__":
agent = create_agent(args) agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2] next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1) probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree_map( next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate) lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs return next_rstate, probs
......
...@@ -237,7 +237,7 @@ def rollout( ...@@ -237,7 +237,7 @@ def rollout(
next_obs, next_obs,
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
): ):
next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs) next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = apply_fn(params, next_obs)[0] logits = apply_fn(params, next_obs)[0]
# sample action: Gumbel-softmax trick # sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
...@@ -263,7 +263,7 @@ def rollout( ...@@ -263,7 +263,7 @@ def rollout(
@jax.jit @jax.jit
def prepare_data(storage: List[Transition]) -> Transition: def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree_map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage) return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2): for update in range(1, args.num_updates + 2):
if update == 10: if update == 10:
...@@ -469,7 +469,7 @@ if __name__ == "__main__": ...@@ -469,7 +469,7 @@ if __name__ == "__main__":
obs_space = envs.observation_space obs_space = envs.observation_space
action_shape = envs.action_space.shape action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}") print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
envs.close() envs.close()
del envs del envs
...@@ -579,7 +579,7 @@ if __name__ == "__main__": ...@@ -579,7 +579,7 @@ if __name__ == "__main__":
sharded_storages: List[Transition], sharded_storages: List[Transition],
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
): ):
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages) storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
impala_loss_grad_fn = jax.value_and_grad(impala_loss, has_aux=True) impala_loss_grad_fn = jax.value_and_grad(impala_loss, has_aux=True)
def update_minibatch(agent_state, minibatch): def update_minibatch(agent_state, minibatch):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment