Commit 662b300f authored by sbl1996@126.com

Update doc and defaults for release

parent 03416f14
This diff is collapsed.
...@@ -131,7 +131,7 @@ if __name__ == "__main__": ...@@ -131,7 +131,7 @@ if __name__ == "__main__":
seed = args.seed + 100000 seed = args.seed + 100000
random.seed(seed) random.seed(seed)
seed = random.randint(0, 1e8) seed = random.randint(0, int(1e8))
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
...@@ -165,6 +165,7 @@ if __name__ == "__main__": ...@@ -165,6 +165,7 @@ if __name__ == "__main__":
oppo_info=args.oppo_info, oppo_info=args.oppo_info,
**env_option, **env_option,
) )
envs1.num_envs = num_envs
envs1 = EnvPreprocess(envs1, skip_mask=not args.oppo_info) envs1 = EnvPreprocess(envs1, skip_mask=not args.oppo_info)
if cross_env: if cross_env:
...@@ -175,11 +176,11 @@ if __name__ == "__main__": ...@@ -175,11 +176,11 @@ if __name__ == "__main__":
deck2=deck2, deck2=deck2,
**env_option, **env_option,
) )
envs2.num_envs = num_envs
key = jax.random.PRNGKey(seed) key = jax.random.PRNGKey(seed)
obs_space1 = envs1.observation_space obs_space1 = envs1.observation_space
envs1.num_envs = num_envs
envs1 = RecordEpisodeStatistics(envs1) envs1 = RecordEpisodeStatistics(envs1)
sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample()) sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample())
agent1 = create_agent1(args) agent1 = create_agent1(args)
...@@ -190,7 +191,6 @@ if __name__ == "__main__": ...@@ -190,7 +191,6 @@ if __name__ == "__main__":
if cross_env: if cross_env:
obs_space2 = envs2.observation_space obs_space2 = envs2.observation_space
envs2.num_envs = num_envs
envs2 = RecordEpisodeStatistics(envs2) envs2 = RecordEpisodeStatistics(envs2)
sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample()) sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample())
else: else:
......
...@@ -106,7 +106,7 @@ class Args: ...@@ -106,7 +106,7 @@ class Args:
"""the discount factor gamma""" """the discount factor gamma"""
num_minibatches: int = 64 num_minibatches: int = 64
"""the number of mini-batches""" """the number of mini-batches"""
update_epochs: int = 2 update_epochs: int = 1
"""the K epochs to update the policy""" """the K epochs to update the policy"""
switch: bool = False switch: bool = False
"""Toggle the use of switch mechanism""" """Toggle the use of switch mechanism"""
...@@ -119,7 +119,7 @@ class Args: ...@@ -119,7 +119,7 @@ class Args:
"""Toggle the use of UPGO for advantages""" """Toggle the use of UPGO for advantages"""
sep_value: bool = True sep_value: bool = True
"""Whether separate value function computation for each player""" """Whether separate value function computation for each player"""
value: Literal["vtrace", "gae"] = "vtrace" value: Literal["vtrace", "gae"] = "gae"
"""the method to learn the value function""" """the method to learn the value function"""
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
...@@ -715,14 +715,14 @@ def main(): ...@@ -715,14 +715,14 @@ def main():
# seeding # seeding
random.seed(args.seed) random.seed(args.seed)
seed = random.randint(0, 1e8) seed = random.randint(0, int(1e8))
seed_offset = args.local_rank seed_offset = args.local_rank
seed += seed_offset seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset) init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed) random.seed(seed)
args.real_seed = random.randint(0, 1e8) args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed) key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1) key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
...@@ -716,14 +716,14 @@ def main(): ...@@ -716,14 +716,14 @@ def main():
# seeding # seeding
random.seed(args.seed) random.seed(args.seed)
seed = random.randint(0, 1e8) seed = random.randint(0, int(1e8))
seed_offset = args.local_rank seed_offset = args.local_rank
seed += seed_offset seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset) init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed) random.seed(seed)
args.real_seed = random.randint(0, 1e8) args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed) key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1) key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
...@@ -743,14 +743,14 @@ def main(): ...@@ -743,14 +743,14 @@ def main():
# seeding # seeding
random.seed(args.seed) random.seed(args.seed)
seed = random.randint(0, 1e8) seed = random.randint(0, int(1e8))
seed_offset = args.local_rank seed_offset = args.local_rank
seed += seed_offset seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset) init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed) random.seed(seed)
args.real_seed = random.randint(0, 1e8) args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed) key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1) key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
...@@ -96,7 +96,7 @@ if __name__ == "__main__": ...@@ -96,7 +96,7 @@ if __name__ == "__main__":
seed = args.seed + 100000 seed = args.seed + 100000
random.seed(seed) random.seed(seed)
seed = random.randint(0, 1e8) seed = random.randint(0, int(1e8))
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
......
This diff is collapsed.
...@@ -646,11 +646,11 @@ class EncoderArgs: ...@@ -646,11 +646,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent""" """whether to use history actions as input for agent"""
card_mask: bool = False card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer""" """whether to mask the padding card as ignored in the transformer"""
noam: bool = False noam: bool = True
"""whether to use Noam architecture for the transformer layer""" """whether to use Noam architecture for the transformer layer"""
action_feats: bool = True action_feats: bool = True
"""whether to use action features for the global state""" """whether to use action features for the global state"""
version: int = 0 version: int = 2
"""the version of the environment and the agent""" """the version of the environment and the agent"""
...@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs): ...@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent""" """the number of channels for the RNN in the agent"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm" rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN""" """the type of RNN to use, None for no RNN"""
film: bool = False film: bool = True
"""whether to use FiLM for the actor""" """whether to use FiLM for the actor"""
rnn_shortcut: bool = False rnn_shortcut: bool = False
"""whether to use shortcut for the RNN""" """whether to use shortcut for the RNN"""
...@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module): ...@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
use_history: bool = True, use_history: bool = True,
card_mask: bool = False, card_mask: bool = False,
rnn_type: str = 'lstm', rnn_type: str = 'lstm',
film: bool = False, film: bool = True,
noam: bool = False, noam: bool = True,
rwkv_head_size: int = 32, rwkv_head_size: int = 32,
action_feats: bool = True, action_feats: bool = True,
rnn_shortcut: bool = False, rnn_shortcut: bool = False,
batch_norm: bool = False, batch_norm: bool = False,
critic_width: int = 128, critic_width: int = 128,
critic_depth: int = 3, critic_depth: int = 3,
version: int = 0, version: int = 2,
q_head: bool = False, q_head: bool = False,
switch: bool = True, switch: bool = True,
freeze_id: bool = False, freeze_id: bool = False,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment