Commit b5d92a22 authored by sbl1996@126.com

Refactor random seed

parent 2a419375
......@@ -184,6 +184,7 @@ class Args:
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
deck_names: Optional[List[str]] = None
real_seed: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
......@@ -259,7 +260,7 @@ def rollout(
if eval_mode != 'bot':
eval_params = params_queue.get()
local_seed = args.seed + device_thread_id * 100
local_seed = args.real_seed + device_thread_id * args.local_num_envs
np.random.seed(local_seed)
envs = make_env(
......@@ -273,7 +274,7 @@ def rollout(
eval_envs = make_env(
args,
local_seed + 10000,
local_seed + 100000,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
......@@ -595,11 +596,12 @@ def main():
args.ckpt_dir, save_fn, n_saved=2)
# seeding
seed_offset = args.local_rank * 1000
seed_offset = args.local_rank
args.seed += seed_offset
random.seed(args.seed)
args.real_seed = random.randint(0, 1e8)
init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed)
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
......@@ -610,7 +612,7 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
envs = make_env(args, 0, 2, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
......
......@@ -191,6 +191,7 @@ class Args:
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
deck_names: Optional[List[str]] = None
real_seed: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
......@@ -266,7 +267,7 @@ def rollout(
if eval_mode != 'bot':
eval_params = params_queue.get()
local_seed = args.seed + device_thread_id * 100
local_seed = args.real_seed + device_thread_id * args.local_num_envs
np.random.seed(local_seed)
envs = make_env(
......@@ -280,7 +281,7 @@ def rollout(
eval_envs = make_env(
args,
local_seed + 10000,
local_seed + 100000,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
......@@ -619,11 +620,12 @@ def main():
args.ckpt_dir, save_fn, n_saved=2)
# seeding
seed_offset = args.local_rank * 1000
seed_offset = args.local_rank
args.seed += seed_offset
random.seed(args.seed)
args.real_seed = random.randint(0, 1e8)
init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed)
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
......@@ -634,7 +636,7 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
envs = make_env(args, 0, 2, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment