Commit 662b300f authored by sbl1996@126.com

Update doc and defaults for release

parent 03416f14
@@ -131,7 +131,7 @@ if __name__ == "__main__":
     seed = args.seed + 100000
     random.seed(seed)
-    seed = random.randint(0, 1e8)
+    seed = random.randint(0, int(1e8))
     random.seed(seed)
     np.random.seed(seed)
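Note: this fix is about argument types, not randomness. random.randint requires integer bounds; passing the float literal 1e8 relied on an automatic conversion that was deprecated in Python 3.10 and removed in Python 3.12, where it raises TypeError. A minimal sketch of the behavior, assuming CPython:

import random

random.seed(42)
# random.randint(0, 1e8)           # TypeError on Python 3.12+
seed = random.randint(0, int(1e8))  # portable across versions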
@@ -165,6 +165,7 @@ if __name__ == "__main__":
         oppo_info=args.oppo_info,
         **env_option,
     )
+    envs1.num_envs = num_envs
     envs1 = EnvPreprocess(envs1, skip_mask=not args.oppo_info)
     if cross_env:
@@ -175,11 +176,11 @@ if __name__ == "__main__":
             deck2=deck2,
             **env_option,
         )
+        envs2.num_envs = num_envs
     key = jax.random.PRNGKey(seed)
     obs_space1 = envs1.observation_space
-    envs1.num_envs = num_envs
     envs1 = RecordEpisodeStatistics(envs1)
     sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample())
     agent1 = create_agent1(args)
@@ -190,7 +191,6 @@ if __name__ == "__main__":
     if cross_env:
         obs_space2 = envs2.observation_space
-        envs2.num_envs = num_envs
         envs2 = RecordEpisodeStatistics(envs2)
         sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample())
     else:
...
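These hunks move the num_envs assignment so it happens before the env is wrapped rather than after. The likely reason is that wrappers read the attribute at construction time; a hypothetical sketch of such a wrapper (names illustrative, not the repository's implementation):

import numpy as np

class StatsWrapper:
    def __init__(self, env):
        self.env = env
        # Buffers are sized from env.num_envs at wrap time, so the
        # attribute must already be set on the inner env.
        self.episode_returns = np.zeros(env.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(env.num_envs, dtype=np.int32)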
@@ -106,7 +106,7 @@ class Args:
     """the discount factor gamma"""
     num_minibatches: int = 64
     """the number of mini-batches"""
-    update_epochs: int = 2
+    update_epochs: int = 1
     """the K epochs to update the policy"""
     switch: bool = False
     """Toggle the use of switch mechanism"""
@@ -119,7 +119,7 @@ class Args:
     """Toggle the use of UPGO for advantages"""
     sep_value: bool = True
     """Whether separate value function computation for each player"""
-    value: Literal["vtrace", "gae"] = "vtrace"
+    value: Literal["vtrace", "gae"] = "gae"
     """the method to learn the value function"""
     gae_lambda: float = 0.95
     """the lambda for the general advantage estimation"""
@@ -715,14 +715,14 @@ def main():
     # seeding
     random.seed(args.seed)
-    seed = random.randint(0, 1e8)
+    seed = random.randint(0, int(1e8))
     seed_offset = args.local_rank
     seed += seed_offset
     init_key = jax.random.PRNGKey(seed - seed_offset)
     random.seed(seed)
-    args.real_seed = random.randint(0, 1e8)
+    args.real_seed = random.randint(0, int(1e8))
     key = jax.random.PRNGKey(args.real_seed)
     key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
...
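The same int(1e8) fix is applied to the identical seeding block in the next two hunks. The scheme itself is worth spelling out: since seed is base + local_rank, seed - seed_offset recovers the shared base, so init_key (and hence the initial parameters) is identical on every rank, while key is derived from the rank-dependent real_seed and differs per rank. A sketch under those assumptions (base_seed and local_rank stand in for args.seed and args.local_rank):

import random
import jax

def make_keys(base_seed, local_rank):
    random.seed(base_seed)
    seed = random.randint(0, int(1e8))
    seed += local_rank
    init_key = jax.random.PRNGKey(seed - local_rank)  # same on all ranks
    random.seed(seed)
    real_seed = random.randint(0, int(1e8))
    key = jax.random.PRNGKey(real_seed)  # rank-dependent
    return init_key, key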
@@ -716,14 +716,14 @@ def main():
     # seeding
     random.seed(args.seed)
-    seed = random.randint(0, 1e8)
+    seed = random.randint(0, int(1e8))
     seed_offset = args.local_rank
     seed += seed_offset
     init_key = jax.random.PRNGKey(seed - seed_offset)
     random.seed(seed)
-    args.real_seed = random.randint(0, 1e8)
+    args.real_seed = random.randint(0, int(1e8))
     key = jax.random.PRNGKey(args.real_seed)
     key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
...
@@ -743,14 +743,14 @@ def main():
     # seeding
     random.seed(args.seed)
-    seed = random.randint(0, 1e8)
+    seed = random.randint(0, int(1e8))
     seed_offset = args.local_rank
     seed += seed_offset
     init_key = jax.random.PRNGKey(seed - seed_offset)
     random.seed(seed)
-    args.real_seed = random.randint(0, 1e8)
+    args.real_seed = random.randint(0, int(1e8))
     key = jax.random.PRNGKey(args.real_seed)
     key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
...
@@ -96,7 +96,7 @@ if __name__ == "__main__":
     seed = args.seed + 100000
     random.seed(seed)
-    seed = random.randint(0, 1e8)
+    seed = random.randint(0, int(1e8))
     random.seed(seed)
     np.random.seed(seed)
...
@@ -83,11 +83,10 @@ class CardEncoder(nn.Module):
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
     oppo_info: bool = False
-    version: int = 0
+    version: int = 2

     @nn.compact
     def __call__(self, x_id, x, mask):
-        assert self.version > 0
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
@@ -105,13 +104,6 @@ class CardEncoder(nn.Module):
         x_loc = x1[:, :, 0]
         x_seq = x1[:, :, 1]
-        if self.version == 0:
-            x_id = mlp(
-                (c, c // 4), kernel_init=default_fc_init2)(x_id)
-            x_id = layer_norm()(x_id)
-            f_loc = layer_norm()(embed(9, c)(x_loc))
-            f_seq = layer_norm()(embed(76, c)(x_seq))
         c_mask = x_loc == 0
         c_mask = c_mask.at[:, 0].set(False)
@@ -130,16 +122,6 @@ class CardEncoder(nn.Module):
         x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
         x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
-        if self.version == 0:
-            x_f = jnp.concatenate([
-                x_owner, x_position, x_overley, x_attribute,
-                x_race, x_level, x_counter, x_negated,
-                x_atk, x_def, x_type], axis=-1)
-            x_f = layer_norm()(x_f)
-            f_cards = jnp.concatenate([x_id, x_f], axis=-1)
-            f_cards = f_cards + f_loc + f_seq
-            f_cards_g = None
-        else:
         x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
         x_id = jax.nn.swish(x_id)
         f_loc = embed(9, c // 16 * 2)(x_loc)
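With the version-0 branches deleted, the surviving card-encoding path projects the id embedding through a single MLP layer with swish and uses narrow (c // 16 * 2) location/sequence embeddings rather than full-width summed ones. A minimal sketch of that surviving pattern, assuming Flax linen; channel sizes follow the diff, everything else is illustrative:

import jax
import flax.linen as nn

class CardIdLocSeq(nn.Module):
    channels: int = 128

    @nn.compact
    def __call__(self, x_id, x_loc, x_seq):
        c = self.channels
        x_id = nn.Dense(c)(x_id)     # stands in for mlp((c,)) above
        x_id = jax.nn.swish(x_id)
        f_loc = nn.Embed(9, c // 16 * 2)(x_loc)
        f_seq = nn.Embed(76, c // 16 * 2)(x_seq)
        return x_id, f_loc, f_seq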
@@ -175,7 +157,7 @@ class GlobalEncoder(nn.Module):
     channels: int = 128
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
-    version: int = 0
+    version: int = 2

     @nn.compact
     def __call__(self, x):
@@ -230,7 +212,7 @@ class Encoder(nn.Module):
     noam: bool = False
     action_feats: bool = True
     oppo_info: bool = False
-    version: int = 0
+    version: int = 2

     @nn.compact
     def __call__(self, x):
@@ -252,7 +234,7 @@ class Encoder(nn.Module):
         card_encoder = CardEncoder(
             channels=c, dtype=self.dtype, param_dtype=self.param_dtype,
             version=self.version, oppo_info=self.oppo_info)
-        ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
+        ActionEncoderCls = ActionEncoderV1
        action_encoder = ActionEncoderCls(
            channels=c, dtype=self.dtype, param_dtype=self.param_dtype)
@@ -313,33 +295,6 @@ class Encoder(nn.Module):
         # History actions
         x_h_actions = x_h_actions.astype(jnp.int32)
-        if self.version == 0:
-            h_mask = x_h_actions[:, :, 2] == 0  # msg == 0
-            h_mask = h_mask.at[:, 0].set(False)
-
-            x_h_id = decode_id(x_h_actions[..., :2])
-            x_h_id = id_embed(x_h_id)
-            if self.freeze_id:
-                x_h_id = jax.lax.stop_gradient(x_h_id)
-            x_h_id = MLP(
-                (c, c), dtype=self.dtype, param_dtype=self.param_dtype,
-                kernel_init=default_fc_init2)(x_h_id)
-
-            x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
-            x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
-            x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
-            x_h_a_feats = jnp.concatenate([
-                *x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
-
-            f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c)(x_h_a_feats))
-            f_h_actions = PositionalEncoding()(f_h_actions)
-            for _ in range(self.num_layers):
-                f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
-                    f_h_actions, src_key_padding_mask=h_mask)
-            f_g_h_actions = layer_norm()(f_h_actions[:, 0])
-        else:
         h_mask = x_h_actions[:, :, 3] == 0  # msg == 0
         h_mask = h_mask.at[:, 0].set(False)
@@ -379,28 +334,6 @@ class Encoder(nn.Module):
         f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
         f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
-        if self.version == 0:
-            spec_index = decode_id(x_actions[..., :2])
-            B = jnp.arange(batch_size)
-            f_a_cards = f_cards[B[:, None], spec_index]
-            f_a_cards = fc_layer(c)(f_a_cards)
-
-            x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
-            x_a_feats = fc_layer(c)(x_a_feats)
-            f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
-            f_actions = fc_layer(c)(nn.leaky_relu(f_actions, negative_slope=0.1))
-            f_actions = layer_norm(dtype=self.dtype)(f_actions)
-
-            a_mask = x_actions[:, :, 2] == 0
-            a_mask = a_mask.at[:, 0].set(False)
-            a_mask_ = (1 - a_mask.astype(f_actions.dtype))
-            f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
-            f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
-
-            if not self.use_history:
-                f_g_h_actions = jnp.zeros_like(f_g_h_actions)
-            f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
-        else:
         spec_index = x_actions[..., 0]
         B = jnp.arange(batch_size)
         f_a_cards = f_cards[B[:, None], spec_index]
@@ -436,6 +369,7 @@ class Encoder(nn.Module):
             g_feats.append(f_g_actions)
         f_state = jnp.concatenate(g_feats, axis=-1)
+        oc = self.out_channels or c
         if self.version == 2:
             f_state = GLUMlp(
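The added `oc = self.out_channels or c` falls back to the encoder width when out_channels is unset. One caveat of the `or` idiom is that it also triggers on an explicit 0; a stricter equivalent, shown as a standalone sketch:

def resolve_out_channels(out_channels, c):
    # Only None triggers the fallback; an explicit 0 would be
    # preserved rather than silently replaced by c.
    return c if out_channels is None else out_channels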
@@ -573,7 +507,7 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False):
     return rstate, f_state

-def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True, return_state=False):
+def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=False, return_state=False):
    if switch:
        def body_fn(cell, carry, x, done, switch):
            rstate, init_rstate2 = carry
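With switch=False now the default, switch_or_main is interpreted as the main-player indicator and stepping routes through rnn_step_by_main rather than the switch-based carry swap. A hedged sketch of the per-sample state-selection idea behind main-based routing (shapes and names are assumptions, not the repository's code):

import jax.numpy as jnp

def select_state(rstate1, rstate2, main):
    # rstate1/rstate2: tuples of [B, C] carries, one per player;
    # main: [B] bool, True where player 1 acts this step.
    m = main[:, None]
    return tuple(jnp.where(m, s1, s2) for s1, s2 in zip(rstate1, rstate2))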
@@ -601,11 +535,11 @@ class EncoderArgs:
     """whether to use history actions as input for agent"""
     card_mask: bool = False
     """whether to mask the padding card as ignored in the transformer"""
-    noam: bool = False
+    noam: bool = True
     """whether to use Noam architecture for the transformer layer"""
     action_feats: bool = True
     """whether to use action features for the global state"""
-    version: int = 0
+    version: int = 2
     """the version of the environment and the agent"""
@@ -615,7 +549,7 @@ class ModelArgs(EncoderArgs):
     """the number of channels for the RNN in the agent"""
     rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
     """the type of RNN to use, None for no RNN"""
-    film: bool = False
+    film: bool = True
     """whether to use FiLM for the actor"""
     oppo_info: bool = False
     """whether to use opponent's information"""
@@ -638,8 +572,8 @@ class RNNAgent(nn.Module):
     use_history: bool = True
     card_mask: bool = False
     rnn_type: str = 'lstm'
-    film: bool = False
-    noam: bool = False
+    film: bool = True
+    noam: bool = True
     rwkv_head_size: int = 32
     action_feats: bool = True
     oppo_info: bool = False
@@ -647,10 +581,10 @@ class RNNAgent(nn.Module):
     batch_norm: bool = False
     critic_width: int = 128
     critic_depth: int = 3
-    version: int = 0
+    version: int = 2
     q_head: bool = False
-    switch: bool = True
+    switch: bool = False
     freeze_id: bool = False
     int_head: bool = False
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
...
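Taken together, these defaults mean a bare construction now yields the release configuration. A usage sketch, assuming RNNAgent is importable from this module:

agent = RNNAgent()  # film=True, noam=True, version=2, switch=False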
@@ -646,11 +646,11 @@ class EncoderArgs:
     """whether to use history actions as input for agent"""
     card_mask: bool = False
     """whether to mask the padding card as ignored in the transformer"""
-    noam: bool = False
+    noam: bool = True
     """whether to use Noam architecture for the transformer layer"""
     action_feats: bool = True
     """whether to use action features for the global state"""
-    version: int = 0
+    version: int = 2
     """the version of the environment and the agent"""
@@ -660,7 +660,7 @@ class ModelArgs(EncoderArgs):
     """the number of channels for the RNN in the agent"""
     rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
     """the type of RNN to use, None for no RNN"""
-    film: bool = False
+    film: bool = True
     """whether to use FiLM for the actor"""
     rnn_shortcut: bool = False
     """whether to use shortcut for the RNN"""
@@ -684,15 +684,15 @@ class RNNAgent(nnx.Module):
         use_history: bool = True,
         card_mask: bool = False,
         rnn_type: str = 'lstm',
-        film: bool = False,
-        noam: bool = False,
+        film: bool = True,
+        noam: bool = True,
         rwkv_head_size: int = 32,
         action_feats: bool = True,
         rnn_shortcut: bool = False,
         batch_norm: bool = False,
         critic_width: int = 128,
         critic_depth: int = 3,
-        version: int = 0,
+        version: int = 2,
         q_head: bool = False,
         switch: bool = True,
         freeze_id: bool = False,
...