Commit 662b300f authored by sbl1996@126.com's avatar sbl1996@126.com

Update doc and defaults for release

parent 03416f14
This diff is collapsed.
......@@ -131,7 +131,7 @@ if __name__ == "__main__":
seed = args.seed + 100000
random.seed(seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
random.seed(seed)
np.random.seed(seed)
......@@ -165,6 +165,7 @@ if __name__ == "__main__":
oppo_info=args.oppo_info,
**env_option,
)
envs1.num_envs = num_envs
envs1 = EnvPreprocess(envs1, skip_mask=not args.oppo_info)
if cross_env:
......@@ -175,11 +176,11 @@ if __name__ == "__main__":
deck2=deck2,
**env_option,
)
envs2.num_envs = num_envs
key = jax.random.PRNGKey(seed)
obs_space1 = envs1.observation_space
envs1.num_envs = num_envs
envs1 = RecordEpisodeStatistics(envs1)
sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample())
agent1 = create_agent1(args)
......@@ -190,7 +191,6 @@ if __name__ == "__main__":
if cross_env:
obs_space2 = envs2.observation_space
envs2.num_envs = num_envs
envs2 = RecordEpisodeStatistics(envs2)
sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample())
else:
......
......@@ -106,7 +106,7 @@ class Args:
"""the discount factor gamma"""
num_minibatches: int = 64
"""the number of mini-batches"""
update_epochs: int = 2
update_epochs: int = 1
"""the K epochs to update the policy"""
switch: bool = False
"""Toggle the use of switch mechanism"""
......@@ -119,7 +119,7 @@ class Args:
"""Toggle the use of UPGO for advantages"""
sep_value: bool = True
"""Whether separate value function computation for each player"""
value: Literal["vtrace", "gae"] = "vtrace"
value: Literal["vtrace", "gae"] = "gae"
"""the method to learn the value function"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
......@@ -715,14 +715,14 @@ def main():
# seeding
random.seed(args.seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, 1e8)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
......@@ -716,14 +716,14 @@ def main():
# seeding
random.seed(args.seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, 1e8)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
......@@ -743,14 +743,14 @@ def main():
# seeding
random.seed(args.seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_key = jax.random.PRNGKey(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, 1e8)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
......
......@@ -96,7 +96,7 @@ if __name__ == "__main__":
seed = args.seed + 100000
random.seed(seed)
seed = random.randint(0, 1e8)
seed = random.randint(0, int(1e8))
random.seed(seed)
np.random.seed(seed)
......
This diff is collapsed.
......@@ -646,11 +646,11 @@ class EncoderArgs:
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
noam: bool = False
noam: bool = True
"""whether to use Noam architecture for the transformer layer"""
action_feats: bool = True
"""whether to use action features for the global state"""
version: int = 0
version: int = 2
"""the version of the environment and the agent"""
......@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
film: bool = True
"""whether to use FiLM for the actor"""
rnn_shortcut: bool = False
"""whether to use shortcut for the RNN"""
......@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
use_history: bool = True,
card_mask: bool = False,
rnn_type: str = 'lstm',
film: bool = False,
noam: bool = False,
film: bool = True,
noam: bool = True,
rwkv_head_size: int = 32,
action_feats: bool = True,
rnn_shortcut: bool = False,
batch_norm: bool = False,
critic_width: int = 128,
critic_depth: int = 3,
version: int = 0,
version: int = 2,
q_head: bool = False,
switch: bool = True,
freeze_id: bool = False,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment