Commit cd2974ce authored by sbl1996@126.com's avatar sbl1996@126.com

Change block location

parent 14bceecd
......@@ -227,7 +227,7 @@ def create_agent(args, eval=False):
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
**asdict(args.m2),
)
)
else:
return RNNAgent(
embedding_shape=args.num_embeddings,
......@@ -315,8 +315,8 @@ def rollout(
done = jnp.array(done)
main = jnp.array(main)
inputs = next_obs, (rstate1, rstate2), done, main
(rstate1, rstate2), logits = agent.apply(params, *inputs)[:2]
(rstate1, rstate2), logits = agent.apply(
params, next_obs, (rstate1, rstate2), done, main)[:2]
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
......@@ -360,9 +360,7 @@ def rollout(
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
......@@ -627,7 +625,6 @@ def main():
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
# rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
rstate = agent.init_rnn_state(1)
params = agent.init(init_key, sample_obs, rstate)
......@@ -687,8 +684,8 @@ def main():
if args.switch:
dones = dones | next_dones
inputs = obs, (rstate1, rstate2), dones, switch_or_mains
new_logits, new_values = create_agent(args).apply(params, *inputs)[1:3]
new_logits, new_values = create_agent(args).apply(
params, obs, (rstate1, rstate2), dones, switch_or_mains)[1:3]
new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
......@@ -938,7 +935,7 @@ def main():
params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
# device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment