Commit 81f63996 authored by sbl1996@126.com

Fix random seed

parent 3e538bc7
...@@ -229,9 +229,12 @@ def rollout( ...@@ -229,9 +229,12 @@ def rollout(
if eval_mode != 'bot': if eval_mode != 'bot':
eval_params = params_queue.get() eval_params = params_queue.get()
local_seed = args.seed + device_thread_id
np.random.seed(local_seed)
envs = make_env( envs = make_env(
args, args,
args.seed + jax.process_index() + device_thread_id, local_seed,
args.local_num_envs, args.local_num_envs,
args.local_env_threads, args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads, thread_affinity_offset=device_thread_id * args.local_env_threads,
...@@ -240,7 +243,7 @@ def rollout( ...@@ -240,7 +243,7 @@ def rollout(
eval_envs = make_env( eval_envs = make_env(
args, args,
args.seed + jax.process_index() + device_thread_id, local_seed,
args.local_eval_episodes, args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True) args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs) eval_envs = RecordEpisodeStatistics(eval_envs)
...@@ -542,11 +545,14 @@ if __name__ == "__main__": ...@@ -542,11 +545,14 @@ if __name__ == "__main__":
args.ckpt_dir, save_fn, n_saved=3) args.ckpt_dir, save_fn, n_saved=3)
# seeding # seeding
seed_offset = args.local_rank * 10000
args.seed += seed_offset
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2) key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_replicated(key, learner_devices) learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file) deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
...@@ -569,7 +575,7 @@ if __name__ == "__main__": ...@@ -569,7 +575,7 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels) rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs)) params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None: if embeddings is not None:
unknown_embed = embeddings.mean(axis=0) unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0) embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
...@@ -775,18 +781,19 @@ if __name__ == "__main__": ...@@ -775,18 +781,19 @@ if __name__ == "__main__":
rollout_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1))
if eval_params: if eval_params:
params_queues[-1].put( params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id])) jax.device_put(eval_params, local_devices[d_id]))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread( threading.Thread(
target=rollout, target=rollout,
args=( args=(
jax.device_put(key, local_devices[d_id]), jax.device_put(actor_keys[actor_thread_id], local_devices[d_id]),
args, args,
rollout_queues[-1], rollout_queues[-1],
params_queues[-1], params_queues[-1],
eval_queue, eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer, writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices, learner_devices,
d_idx * args.num_actor_threads + thread_id, actor_thread_id,
), ),
).start() ).start()
params_queues[-1].put(device_params) params_queues[-1].put(device_params)
......
...@@ -229,9 +229,12 @@ def rollout( ...@@ -229,9 +229,12 @@ def rollout(
if eval_mode != 'bot': if eval_mode != 'bot':
eval_params = params_queue.get() eval_params = params_queue.get()
local_seed = args.seed + device_thread_id
np.random.seed(local_seed)
envs = make_env( envs = make_env(
args, args,
args.seed + jax.process_index() + device_thread_id, local_seed,
args.local_num_envs, args.local_num_envs,
args.local_env_threads, args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads, thread_affinity_offset=device_thread_id * args.local_env_threads,
...@@ -240,7 +243,7 @@ def rollout( ...@@ -240,7 +243,7 @@ def rollout(
eval_envs = make_env( eval_envs = make_env(
args, args,
args.seed + jax.process_index() + device_thread_id, local_seed,
args.local_eval_episodes, args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True) args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs) eval_envs = RecordEpisodeStatistics(eval_envs)
...@@ -552,11 +555,14 @@ if __name__ == "__main__": ...@@ -552,11 +555,14 @@ if __name__ == "__main__":
args.ckpt_dir, save_fn, n_saved=3) args.ckpt_dir, save_fn, n_saved=3)
# seeding # seeding
seed_offset = args.local_rank * 10000
args.seed += seed_offset
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2) key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_replicated(key, learner_devices) learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file) deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
...@@ -579,7 +585,7 @@ if __name__ == "__main__": ...@@ -579,7 +585,7 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels) rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs)) params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None: if embeddings is not None:
unknown_embed = embeddings.mean(axis=0) unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0) embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
...@@ -801,18 +807,19 @@ if __name__ == "__main__": ...@@ -801,18 +807,19 @@ if __name__ == "__main__":
rollout_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1))
if eval_params: if eval_params:
params_queues[-1].put( params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id])) jax.device_put(eval_params, local_devices[d_id]))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread( threading.Thread(
target=rollout, target=rollout,
args=( args=(
jax.device_put(key, local_devices[d_id]), jax.device_put(actor_keys[actor_thread_id], local_devices[d_id]),
args, args,
rollout_queues[-1], rollout_queues[-1],
params_queues[-1], params_queues[-1],
eval_queue, eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer, writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices, learner_devices,
d_idx * args.num_actor_threads + thread_id, actor_thread_id,
), ),
).start() ).start()
params_queues[-1].put(device_params) params_queues[-1].put(device_params)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment