Commit 671ed3c6 authored by sbl1996@126.com

Replace tree_map with tree.map

parent 9670ed68
......@@ -207,7 +207,7 @@ if __name__ == "__main__":
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
-sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
+sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
print(jax.tree.leaves(params)[0].devices())
with open(args.checkpoint1, "rb") as f:
......
......@@ -224,7 +224,7 @@ if __name__ == "__main__":
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
-sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
+sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
......
......@@ -153,7 +153,7 @@ if __name__ == "__main__":
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
-sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
+sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
......@@ -171,7 +171,7 @@ if __name__ == "__main__":
agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1)
-next_rstate = jax.tree_map(
+next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs
......
......@@ -237,7 +237,7 @@ def rollout(
next_obs,
key: jax.random.PRNGKey,
):
-next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
+next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = apply_fn(params, next_obs)[0]
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
......@@ -263,7 +263,7 @@ def rollout(
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
-return jax.tree_map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
+return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
......@@ -469,7 +469,7 @@ if __name__ == "__main__":
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
-sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
+sample_obs = jax.tree.map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
envs.close()
del envs
......@@ -579,7 +579,7 @@ if __name__ == "__main__":
sharded_storages: List[Transition],
key: jax.random.PRNGKey,
):
-storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
+storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
impala_loss_grad_fn = jax.value_and_grad(impala_loss, has_aux=True)
def update_minibatch(agent_state, minibatch):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment