Commit 0559e98c authored by biluo.shen

Update agent

parent b1054104
...@@ -30,7 +30,7 @@ class Args: ...@@ -30,7 +30,7 @@ class Args:
"""the name of this experiment""" """the name of this experiment"""
seed: int = 1 seed: int = 1
"""seed of the experiment""" """seed of the experiment"""
torch_deterministic: bool = True torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`""" """if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True cuda: bool = True
"""if toggled, cuda will be enabled by default""" """if toggled, cuda will be enabled by default"""
......
...@@ -29,7 +29,7 @@ class Encoder(nn.Module): ...@@ -29,7 +29,7 @@ class Encoder(nn.Module):
c = channels c = channels
self.loc_embed = nn.Embedding(9, c) self.loc_embed = nn.Embedding(9, c)
self.loc_norm = nn.LayerNorm(c, elementwise_affine=affine) self.loc_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.seq_embed = nn.Embedding(61, c) self.seq_embed = nn.Embedding(76, c)
self.seq_norm = nn.LayerNorm(c, elementwise_affine=affine) self.seq_norm = nn.LayerNorm(c, elementwise_affine=affine)
linear = lambda in_features, out_features: nn.Linear(in_features, out_features, bias=bias) linear = lambda in_features, out_features: nn.Linear(in_features, out_features, bias=bias)
...@@ -83,7 +83,7 @@ class Encoder(nn.Module): ...@@ -83,7 +83,7 @@ class Encoder(nn.Module):
self.lp_fc_emb = linear(c_num, c // 4) self.lp_fc_emb = linear(c_num, c // 4)
self.oppo_lp_fc_emb = linear(c_num, c // 4) self.oppo_lp_fc_emb = linear(c_num, c // 4)
self.turn_embed = nn.Embedding(20, c // 8) self.turn_embed = nn.Embedding(20, c // 8)
self.phase_embed = nn.Embedding(10, c // 8) self.phase_embed = nn.Embedding(11, c // 8)
self.if_first_embed = nn.Embedding(2, c // 8) self.if_first_embed = nn.Embedding(2, c // 8)
self.is_my_turn_embed = nn.Embedding(2, c // 8) self.is_my_turn_embed = nn.Embedding(2, c // 8)
...@@ -97,15 +97,15 @@ class Encoder(nn.Module): ...@@ -97,15 +97,15 @@ class Encoder(nn.Module):
divisor = 8 divisor = 8
self.a_msg_embed = nn.Embedding(30, c // divisor) self.a_msg_embed = nn.Embedding(30, c // divisor)
self.a_act_embed = nn.Embedding(11, c // divisor) self.a_act_embed = nn.Embedding(13, c // divisor)
self.a_yesno_embed = nn.Embedding(3, c // divisor) self.a_yesno_embed = nn.Embedding(3, c // divisor)
self.a_phase_embed = nn.Embedding(4, c // divisor) self.a_phase_embed = nn.Embedding(4, c // divisor)
self.a_cancel_finish_embed = nn.Embedding(3, c // divisor) self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
self.a_position_embed = nn.Embedding(9, c // divisor) self.a_position_embed = nn.Embedding(9, c // divisor)
self.a_option_embed = nn.Embedding(4, c // divisor // 2) self.a_option_embed = nn.Embedding(6, c // divisor // 2)
self.a_number_embed = nn.Embedding(13, c // divisor // 2) self.a_number_embed = nn.Embedding(13, c // divisor // 2)
self.a_place_embed = nn.Embedding(31, c // divisor // 2) self.a_place_embed = nn.Embedding(31, c // divisor // 2)
self.a_attrib_embed = nn.Embedding(31, c // divisor // 2) self.a_attrib_embed = nn.Embedding(10, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine) self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.a_card_norm = nn.LayerNorm(c, elementwise_affine=False) self.a_card_norm = nn.LayerNorm(c, elementwise_affine=False)
......
...@@ -36,7 +36,6 @@ def setup(backend, rank, world_size, port): ...@@ -36,7 +36,6 @@ def setup(backend, rank, world_size, port):
def mp_start(run): def mp_start(run):
mp.set_start_method('forkserver')
world_size = int(os.getenv("WORLD_SIZE", "1")) world_size = int(os.getenv("WORLD_SIZE", "1"))
if world_size == 1: if world_size == 1:
run(local_rank=0, world_size=world_size) run(local_rank=0, world_size=world_size)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment