Commit f8a13371 authored by sbl1996@126.com

Agent prepared for multi select

parent e319152f
......@@ -72,10 +72,6 @@ python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embedding
## TODO
### Documentation
- Add documentation for building and running
### Training
- Evaluation with old models during training
- LSTM for memory
......@@ -85,6 +81,7 @@ python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embedding
- MCTS-based planning
- Support of play in YGOPro
## Related Projects
TODO
\ No newline at end of file
### Related Projects
- [yugioh-ai](https://github.com/melvinzhang/yugioh-ai)
- [yugioh-game](https://github.com/tspivey/yugioh-game)
- [envpool](https://github.com/sail-sg/envpool)
\ No newline at end of file
......@@ -4,7 +4,7 @@
- float transform: max 65535 -> 2 bytes
- count
## Card (39)
## Card
- 0,1: card id, uint16 -> 2 uint8, name+desc
- 2: location, discrete, 0: N/A, 1+: same as location2str (9)
- 3: seq, discrete, 0: N/A, 1+: seq in location
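
A minimal illustrative sketch (not repository code) of the `uint16 -> 2 uint8` transform used above for card ids and clamped float values; the helper names are made up:

```python
import numpy as np

def to_bytes2(value):
    """Split uint16 values into (high, low) uint8 pairs."""
    value = np.asarray(value, dtype=np.uint16)
    return np.stack([value >> 8, value & 0xFF], axis=-1).astype(np.uint8)

def from_bytes2(pair):
    """Recover the uint16 value from a (high, low) uint8 pair."""
    pair = np.asarray(pair, dtype=np.uint16)
    return pair[..., 0] * 256 + pair[..., 1]

ids = np.array([0, 1, 4999, 65535])
assert np.array_equal(from_bytes2(to_bytes2(ids)), ids)
```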
......@@ -44,10 +44,10 @@
- 22: is_end, discrete, 0: False, 1: True
## Legal Actions (max 24)
- 0,1: spec index or card id, uint16 -> 2 uint8
- 2: msg, discrete, 0: N/A, 1+: same as msg2str (11)
- act: 1, int (11)
## Legal Actions
- 0,1: spec index, uint16 -> 2 uint8
- 2: msg, discrete, 0: N/A, 1+: same as msg2str (15)
- 3: act, discrete (11)
  - N/A
  - t: Set
  - r: Reposition
......@@ -59,32 +59,33 @@
  - v2: Activate the second effect
  - v3: Activate the third effect
  - v4: Activate the fourth effect
- yes/no: 1, int (3)
- 4: yes/no, discrete (3)
  - N/A
  - Yes
  - No
- phase: 1, int (4)
- 5: phase, discrete (4)
  - N/A
  - Battle (b)
  - Main Phase 2 (m)
  - End Phase (e)
- cancel: 1
- 6: cancel, discrete (2)
  - N/A
  - Cancel
- finish: 1
- 7: finish, discrete (2)
  - N/A
  - Finish
- position: 1, int, 0: N/A, same as position2str
- option: 1, int, 0: N/A
- number: 1, int, 0: N/A
- place: 1, int (31), 0: N/A
- 8: position, discrete, 0: N/A, same as position2str
- 9: option, discrete, 0: N/A
- 10: number, discrete, 0: N/A
- 11: place, discrete
  - 0: N/A
  - 1-7: m
  - 8-15: s
  - 16-22: om
  - 23-30: os
- attribute: 1, int, 0: N/A, same as attribute2id
- 12: attribute, discrete, 0: N/A, same as attribute2id
## History Actions
- id: 2x4, uint16 -> 2 uint8, name+desc
- same as Legal Actions
- 0,1: card id, uint16 -> 2 uint8
- others same as legal actions
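
For reference, a hedged sketch of how one legal-action feature row could be unpacked following the column layout documented above; the helper and field names are illustrative, not repository code:

```python
import numpy as np

ACTION_FIELDS = [
    "msg", "act", "yes_no", "phase", "cancel", "finish",
    "position", "option", "number", "place", "attribute",
]

def decode_action_row(row):
    """Unpack one 13-column action row; 0 means N/A for every discrete field."""
    row = np.asarray(row, dtype=np.int64)
    decoded = {"spec_index": int(row[0]) * 256 + int(row[1])}
    for offset, name in enumerate(ACTION_FIELDS, start=2):
        decoded[name] = int(row[offset])
    return decoded

row = np.zeros(13, dtype=np.uint8)
row[2] = 1  # some msg id
print(decode_action_row(row))
```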
......@@ -142,7 +142,7 @@ if __name__ == "__main__":
    envs = RecordEpisodeStatistics(envs)
    if args.agent:
        if args.checkpoint.endswith(".ptj"):
        if args.checkpoint and args.checkpoint.endswith(".ptj"):
            agent = torch.jit.load(args.checkpoint)
        else:
            # count lines of code_list
......
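
The added guard avoids calling `.endswith` on a missing checkpoint path. A hedged sketch of the overall loading pattern, assuming `.ptj` files are TorchScript exports and other checkpoints are plain state dicts; `build_agent` is a hypothetical constructor, not part of this repository:

```python
from typing import Optional

import torch

def load_agent(checkpoint: Optional[str], build_agent, device: str = "cpu"):
    """Load a TorchScript module for .ptj files, otherwise an eager model + state_dict."""
    if checkpoint and checkpoint.endswith(".ptj"):
        return torch.jit.load(checkpoint, map_location=device)
    agent = build_agent().to(device)
    if checkpoint:
        agent.load_state_dict(torch.load(checkpoint, map_location=device))
    return agent
```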
......@@ -12,7 +12,10 @@ AUTHOR = 'Hastur'
REQUIRES_PYTHON = '>=3.8.0'
VERSION = None
REQUIRED = []
REQUIRED = [
"tyro",
"pandas",
]
here = os.path.dirname(os.path.abspath(__file__))
......
......@@ -80,7 +80,11 @@ class Encoder(nn.Module):
        n_embed = 1 + n_embed # 1 (index 0) for unknown
        self.id_embed = nn.Embedding(n_embed, embed_dim)
        self.id_fc_emb = linear(1024, c // 4)
        self.id_fc_emb = nn.Sequential(
            linear(embed_dim, c),
            nn.ReLU(),
            linear(c, c // 4),
        )
        self.id_norm = nn.LayerNorm(c // 4, elementwise_affine=False)
......@@ -115,8 +119,6 @@ class Encoder(nn.Module):
        self.if_first_embed = nn.Embedding(2, c // 8)
        self.is_my_turn_embed = nn.Embedding(2, c // 8)
        self.my_deck_fc_emb = linear(1024, c // 4)
        self.global_norm_pre = nn.LayerNorm(c * 2, elementwise_affine=affine)
        self.global_net = nn.Sequential(
            nn.Linear(c * 2, c * 2),
......@@ -131,8 +133,9 @@ class Encoder(nn.Module):
        self.a_act_embed = nn.Embedding(13, c // divisor)
        self.a_yesno_embed = nn.Embedding(3, c // divisor)
        self.a_phase_embed = nn.Embedding(4, c // divisor)
        self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
        self.a_position_embed = nn.Embedding(9, c // divisor)
        self.a_cancel_embed = nn.Embedding(3, c // divisor)
        self.a_finish_embed = nn.Embedding(3, c // divisor // 2)
        self.a_position_embed = nn.Embedding(9, c // divisor // 2)
        self.a_option_embed = nn.Embedding(6, c // divisor // 2)
        self.a_number_embed = nn.Embedding(13, c // divisor // 2)
        self.a_place_embed = nn.Embedding(31, c // divisor // 2)
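
Splitting the shared cancel/finish table into separate `a_cancel_embed` and `a_finish_embed` while halving the finish and position widths leaves the summed width of the embeddings shown in this hunk unchanged (they are concatenated via `torch.cat` in the forward pass further below). A quick arithmetic check with illustrative values for `c` and `divisor`:

```python
# Illustrative values; the real c and divisor come from the Encoder config.
c, divisor = 128, 8
w = c // divisor
old = 5 * w + 3 * (w // 2)  # act, yesno, phase, cancel_finish, position + option, number, place
new = 4 * w + 5 * (w // 2)  # act, yesno, phase, cancel + finish, position, option, number, place
assert old == new
```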
......@@ -147,7 +150,11 @@ class Encoder(nn.Module):
            nn.Linear(c, c),
        )
        self.h_id_fc_emb = linear(1024, c)
        self.h_id_fc_emb = nn.Sequential(
            linear(embed_dim, c),
            nn.ReLU(),
            linear(c, c),
        )
        self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
        self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
......@@ -169,16 +176,26 @@ class Encoder(nn.Module):
        self.init_embeddings()
    def init_linear(self, m, scale):
        nn.init.uniform_(m.weight, -scale, scale)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    def init_embeddings(self, scale=0.0001):
        for n, m in self.named_modules():
            if isinstance(m, nn.Embedding):
                nn.init.uniform_(m.weight, -scale, scale)
            elif n in ["atk_fc_emb", "def_fc_emb"]:
                nn.init.uniform_(m.weight, -scale * 10, scale * 10)
                self.init_linear(m, scale * 10)
            elif n in ["lp_fc_emb", "oppo_lp_fc_emb"]:
                nn.init.uniform_(m.weight, -scale, scale)
                self.init_linear(m, scale)
            elif "fc_emb" in n:
                nn.init.uniform_(m.weight, -scale, scale)
                if isinstance(m, nn.Linear):
                    self.init_linear(m, scale)
                elif isinstance(m, nn.Sequential):
                    for mm in m:
                        if isinstance(mm, nn.Linear):
                            self.init_linear(mm, scale)
    def load_embeddings(self, embeddings):
        weight = self.id_embed.weight
......@@ -198,59 +215,43 @@ class Encoder(nn.Module):
        x_a_act = self.a_act_embed(x[:, :, 1])
        x_a_yesno = self.a_yesno_embed(x[:, :, 2])
        x_a_phase = self.a_phase_embed(x[:, :, 3])
        x_a_cancel = self.a_cancel_finish_embed(x[:, :, 4])
        x_a_position = self.a_position_embed(x[:, :, 5])
        x_a_option = self.a_option_embed(x[:, :, 6])
        x_a_number = self.a_number_embed(x[:, :, 7])
        x_a_place = self.a_place_embed(x[:, :, 8])
        x_a_attrib = self.a_attrib_embed(x[:, :, 9])
        return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib
        x_a_cancel = self.a_cancel_embed(x[:, :, 4])
        x_a_finish = self.a_finish_embed(x[:, :, 5])
        x_a_position = self.a_position_embed(x[:, :, 6])
        x_a_option = self.a_option_embed(x[:, :, 7])
        x_a_number = self.a_number_embed(x[:, :, 8])
        x_a_place = self.a_place_embed(x[:, :, 9])
        x_a_attrib = self.a_attrib_embed(x[:, :, 10])
        return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib
    def get_action_card_(self, x, f_cards):
        b, n, c = x.shape
        m = c // 2
        spec_index = x.view(b, n, m, 2)
        spec_index = spec_index[..., 0] * 256 + spec_index[..., 1]
        mask = spec_index != 0
        mask[:, :, 0] = True
        spec_index = spec_index.view(b, -1)
        b = x.shape[0]
        spec_index = x[:, :, 0] * 256 + x[:, :, 1]
        B = torch.arange(b, device=spec_index.device)
        f_a_actions = f_cards[B[:, None], spec_index]
        f_a_actions = f_a_actions.view(b, n, m, -1)
        f_a_actions = (f_a_actions * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
        return f_a_actions
    def get_h_action_card_(self, x):
        b, n, _ = x.shape
        x_ids = x.view(b, n, -1, 2)
        x_ids = x_ids[..., 0] * 256 + x_ids[..., 1]
        mask = x_ids != 0
        mask[:, :, 0] = True
        x_ids = x[:, :, 0] * 256 + x[:, :, 1]
        x_ids = self.id_embed(x_ids)
        x_ids = self.h_id_fc_emb(x_ids)
        x_ids = (x_ids * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
        return x_ids
    def encode_card_id(self, x):
        x_id = self.id_embed(x)
        x_id = self.id_fc_emb(x_id)
        x_id = self.id_norm(x_id)
        return x_id
        x_ids = x[:, :, 0] * 256 + x[:, :, 1]
        x_ids = self.id_embed(x_ids)
        x_ids = self.id_fc_emb(x_ids)
        return x_ids
    def encode_card_feat1(self, x1):
        x_owner = self.owner_embed(x1[:, :, 2])
        x_position = self.position_embed(x1[:, :, 3])
        x_overley = self.overley_embed(x1[:, :, 4])
        x_attribute = self.attribute_embed(x1[:, :, 5])
        x_race = self.race_embed(x1[:, :, 6])
        x_level = self.level_embed(x1[:, :, 7])
        x_counter = self.counter_embed(x1[:, :, 8])
        x_negated = self.negated_embed(x1[:, :, 9])
        x_owner = self.owner_embed(x1[:, :, 4])
        x_position = self.position_embed(x1[:, :, 5])
        x_overley = self.overley_embed(x1[:, :, 6])
        x_attribute = self.attribute_embed(x1[:, :, 7])
        x_race = self.race_embed(x1[:, :, 8])
        x_level = self.level_embed(x1[:, :, 9])
        x_counter = self.counter_embed(x1[:, :, 10])
        x_negated = self.negated_embed(x1[:, :, 11])
        return x_owner, x_position, x_overley, x_attribute, x_race, x_level, x_counter, x_negated
    def encode_card_feat2(self, x2):
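
The indexing in `get_action_card_` above gathers, per batch element, the card feature row addressed by each action's spec index. A small self-contained sketch of that pattern (shapes are illustrative):

```python
import torch

b, n_cards, n_actions, c = 2, 5, 3, 8
f_cards = torch.randn(b, n_cards, c)               # per-card features
spec_index = torch.randint(0, n_cards, (b, n_actions))

B = torch.arange(b)                                # (b,)
f_a_actions = f_cards[B[:, None], spec_index]      # (b, n_actions, c)

# Equivalent gather formulation, as a cross-check.
ref = torch.gather(f_cards, 1, spec_index.unsqueeze(-1).expand(-1, -1, c))
assert torch.equal(f_a_actions, ref)
```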
......@@ -287,15 +288,14 @@ class Encoder(nn.Module):
        x_global = x['global_']
        x_actions = x['actions_']
        x_card_ids = x_cards[:, :, :2].long()
        x_card_ids = x_card_ids[..., 0] * 256 + x_card_ids[..., 1]
        x_cards_1 = x_cards[:, :, :12].long()
        x_cards_2 = x_cards[:, :, 12:].to(torch.float32)
        x_cards_1 = x_cards[:, :, 2:11].long()
        x_cards_2 = x_cards[:, :, 11:].to(torch.float32)
        x_id = self.encode_card_id(x_cards_1[:, :, :2])
        x_id = self.id_norm(x_id)
        x_id = self.encode_card_id(x_card_ids)
        f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 0]))
        f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 1]))
        f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 2]))
        f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3]))
        x_feat1 = self.encode_card_feat1(x_cards_1)
        x_feat2 = self.encode_card_feat2(x_cards_2)
......@@ -323,16 +323,14 @@ class Encoder(nn.Module):
        x_actions = x_actions.long()
        max_multi_select = (x_actions.shape[-1] - 9) // 2
        mo = max_multi_select * 2
        f_a_cards = self.get_action_card_(x_actions[..., :mo], f_cards)
        f_a_cards = self.get_action_card_(x_actions[..., :2], f_cards)
        f_a_cards = f_a_cards + self.a_card_proj(self.a_card_norm(f_a_cards))
        x_a_feats = self.encode_action_(x_actions[..., mo:])
        x_a_feats = self.encode_action_(x_actions[..., 2:])
        x_a_feats = torch.cat(x_a_feats, dim=-1)
        f_actions = f_a_cards + self.a_feat_norm(x_a_feats)
        mask = x_actions[:, :, mo] == 0 # msg == 0
        mask = x_actions[:, :, 2] == 0 # msg == 0
        valid = x['global_'][:, -1] == 0
        mask[:, 0] &= valid
        for layer in self.action_card_net:
......@@ -342,9 +340,9 @@ class Encoder(nn.Module):
        x_h_actions = x['h_actions_']
        x_h_actions = x_h_actions.long()
        x_h_id = self.get_h_action_card_(x_h_actions[..., :mo])
        x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
        x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
        x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
        x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
        f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
        f_h_actions = self.action_history_pe(f_h_actions)
......
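
A hedged sketch of the action mask built in `forward` above: `msg == 0` marks padded action slots, and slot 0 is left unmasked for samples flagged invalid in `global_`, presumably so no attention row is fully masked. The attention layer and shapes here are illustrative stand-ins, not the repository's `action_card_net`:

```python
import torch
import torch.nn as nn

b, n_actions, c = 2, 24, 16
x_actions = torch.zeros(b, n_actions, 13, dtype=torch.long)
x_actions[0, :3, 2] = torch.tensor([1, 5, 7])  # three real actions for sample 0
x_global = torch.zeros(b, 8)
x_global[1, -1] = 1                            # sample 1 flagged as invalid

mask = x_actions[:, :, 2] == 0                 # True where the slot is padding
valid = x_global[:, -1] == 0
mask[:, 0] &= valid                            # unmask slot 0 for invalid samples

f_actions = torch.randn(b, n_actions, c)
attn = nn.MultiheadAttention(c, num_heads=4, batch_first=True)
out, _ = attn(f_actions, f_actions, f_actions, key_padding_mask=mask)
```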