Commit cd2974ce authored by sbl1996@126.com

Change block location

parent 14bceecd
...@@ -227,7 +227,7 @@ def create_agent(args, eval=False): ...@@ -227,7 +227,7 @@ def create_agent(args, eval=False):
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32, dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32, param_dtype=jnp.float32,
**asdict(args.m2), **asdict(args.m2),
) )
else: else:
return RNNAgent( return RNNAgent(
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
...@@ -315,8 +315,8 @@ def rollout( ...@@ -315,8 +315,8 @@ def rollout(
done = jnp.array(done) done = jnp.array(done)
main = jnp.array(main) main = jnp.array(main)
inputs = next_obs, (rstate1, rstate2), done, main (rstate1, rstate2), logits = agent.apply(
(rstate1, rstate2), logits = agent.apply(params, *inputs)[:2] params, next_obs, (rstate1, rstate2), done, main)[:2]
action, key = categorical_sample(logits, key) action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key return next_obs, done, main, rstate1, rstate2, action, logits, key
...@@ -360,9 +360,7 @@ def rollout( ...@@ -360,9 +360,7 @@ def rollout(
if args.concurrency: if args.concurrency:
if update != 2: if update != 2:
params = params_queue.get() params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][ params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
# "embedding"
# ].block_until_ready()
actor_policy_version += 1 actor_policy_version += 1
else: else:
params = params_queue.get() params = params_queue.get()
...@@ -627,7 +625,6 @@ def main(): ...@@ -627,7 +625,6 @@ def main():
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac return args.learning_rate * frac
# rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
rstate = agent.init_rnn_state(1) rstate = agent.init_rnn_state(1)
params = agent.init(init_key, sample_obs, rstate) params = agent.init(init_key, sample_obs, rstate)
...@@ -687,8 +684,8 @@ def main(): ...@@ -687,8 +684,8 @@ def main():
if args.switch: if args.switch:
dones = dones | next_dones dones = dones | next_dones
inputs = obs, (rstate1, rstate2), dones, switch_or_mains new_logits, new_values = create_agent(args).apply(
new_logits, new_values = create_agent(args).apply(params, *inputs)[1:3] params, obs, (rstate1, rstate2), dones, switch_or_mains)[1:3]
new_values = new_values.squeeze(-1) new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical( ratios = distrax.importance_sampling_ratios(distrax.Categorical(
...@@ -938,7 +935,7 @@ def main(): ...@@ -938,7 +935,7 @@ def main():
params_queue_put_time = 0 params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id]) device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready() # device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
params_queue_put_start = time.time() params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params) params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment