Commit 81f63996 authored by sbl1996@126.com

Fix random seed

parent 3e538bc7
......@@ -229,9 +229,12 @@ def rollout(
if eval_mode != 'bot':
eval_params = params_queue.get()
local_seed = args.seed + device_thread_id
np.random.seed(local_seed)
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
local_seed,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
......@@ -240,7 +243,7 @@ def rollout(
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
local_seed,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
......@@ -542,11 +545,14 @@ if __name__ == "__main__":
args.ckpt_dir, save_fn, n_saved=3)
# seeding
seed_offset = args.local_rank * 10000
args.seed += seed_offset
random.seed(args.seed)
np.random.seed(args.seed)
init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
......@@ -569,7 +575,7 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
......@@ -776,17 +782,18 @@ if __name__ == "__main__":
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
jax.device_put(actor_keys[actor_thread_id], local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
actor_thread_id,
),
).start()
params_queues[-1].put(device_params)
......
......@@ -229,9 +229,12 @@ def rollout(
if eval_mode != 'bot':
eval_params = params_queue.get()
local_seed = args.seed + device_thread_id
np.random.seed(local_seed)
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
local_seed,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
......@@ -240,7 +243,7 @@ def rollout(
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
local_seed,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
......@@ -552,11 +555,14 @@ if __name__ == "__main__":
args.ckpt_dir, save_fn, n_saved=3)
# seeding
seed_offset = args.local_rank * 10000
args.seed += seed_offset
random.seed(args.seed)
np.random.seed(args.seed)
init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
......@@ -579,7 +585,7 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
......@@ -802,17 +808,18 @@ if __name__ == "__main__":
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
jax.device_put(actor_keys[actor_thread_id], local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
actor_thread_id,
),
).start()
params_queues[-1].put(device_params)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment