Commit a932ce02 authored by sbl1996@126.com's avatar sbl1996@126.com

Support more cards

parent 03b43512
......@@ -55,3 +55,4 @@
03814632
72860663
!side
27204312
......@@ -56,17 +56,7 @@
41999284
41999284
!side
75732622
15397015
15397015
73642296
23434538
5821478
77058170
3679218
25774450
43898403
23002292
23002292
84749824
84749824
73915052
73915053
73915054
73915055
#created by wxapp_ygo
#main
93454062
93454062
93454062
07478431
23434538
23434538
23434538
14558127
14558127
14558127
29942771
29942771
29942771
35726888
92107604
24224830
24299458
94445733
66712905
66712905
68957034
68957034
68957034
30430448
30430448
20618850
20618850
67835547
67835547
93229151
93229151
93229151
31562086
31562086
34813545
34813545
34813545
03734202
03734202
03734202
#extra
55990317
55990317
55990317
28373620
28373620
33198837
42566602
52445243
80666118
87188910
84815190
96633955
66011101
90590303
08728498
!side
70902743
\ No newline at end of file
......@@ -54,7 +54,6 @@
94977269
52687916
73580471
31924889
56832966
84013237
82633039
......
#created by ...
#main
27204311
48452496
48452496
72270339
72270339
72270339
14558127
14558127
14558127
23434538
23434538
23434538
35405755
45663742
9674034
9674034
9674034
90241276
90241276
9742784
12058741
94145021
94145021
97268402
97268402
97268402
2295440
24081957
89023486
89023486
24224830
24224830
26700718
80845034
80845034
80845034
53639887
10045474
10045474
38511382
#extra
84815190
27548199
79606837
50091196
98127546
20665527
45112597
4280258
2772337
61245672
48815792
87871125
27381364
65741786
41999284
!side
27204312
......@@ -57,18 +57,5 @@
08491308
12421694
!side
19613556
12580477
12580477
04031928
04031928
18144506
35269904
08267140
08267140
31849106
31849106
83326048
83326048
15693423
15693423
52340445
27204312
#created by ...
#main
27204311
48452496
66431519
2526224
2526224
2526224
72270339
18621798
14558127
14558127
14558127
23434538
23434538
23434538
9674034
9674034
9674034
90241276
90241276
90241276
90681088
94145021
97268402
97268402
24081957
85106525
85106525
85106525
89023486
24224830
24224830
80845034
80845034
80845034
91703676
65305978
57554544
10045474
10045474
10045474
#extra
93039339
64182380
57134592
20665527
45112597
4280258
2772337
2772337
61245672
8264361
48815792
87871125
29301450
65741786
41999284
!side
27204312
......@@ -57,3 +57,8 @@
32519092
78917791
!side
20001444
27204312
93490857
56495148
14821891
#created by ...
#main
27204311
27204311
27204311
26866984
88284599
51296484
51296484
51296484
14558127
14558127
14558127
92919429
92919429
92919429
23434538
23434538
23434538
25801745
25801745
25801745
4810828
10804018
10774240
10774240
13048472
13048472
13048472
25311006
49238328
49238328
52472775
52472775
24224830
24224830
39114494
98477480
98477480
98477480
10045474
10045474
86310763
#extra
80532587
22850702
22850702
79606837
93039339
98127546
73898890
73898890
9839945
29301450
29301450
29301450
71818935
41999284
94259633
!side
27204312
......@@ -10,6 +10,7 @@
- attribute: 1, int, 0: N/A, same as attribute2str[2:]
- race: 1, int, 0: N/A, same as race2str
- level: 1, int, 0: N/A
- counter: 1, int, 0: N/A
- atk: 2, max 65535 to 2 bytes
- def: 2, max 65535 to 2 bytes
- type: 25, multi-hot, same as type2str
......
......@@ -2,17 +2,11 @@
## Unsupported
- Many (Crossout Designator)
- Blackwing (add_counter)
- Magician (pendulum)
- Shaddoll (add_counter)
- Shiranui (Fairy Tail - Snow)
- Hero (random_selected)
# Message
## random_selected
Not supported
## add_counter
Not supported
......@@ -35,14 +29,5 @@ Only 1 attribute is announced at a time.
Not supported:
- DNA Checkup
# Summon
## Tribute Summon
Through `select_tribute` (multi-select)
## Link Summon
Through `select_unselect_card` (select 1 card at a time)
## Synchro Summon
- `select_card` to choose the tuner (usually 1 card)
- `select_sum` to choose the non-tuner (1 card at a time)
## announce_number
Only 1-12 is supported.
......@@ -454,3 +454,205 @@
13764603
75524093
32519092
89870349
56733747
19324993
18094166
52947044
93347961
75047173
30757127
40044918
21143940
16605586
45906428
59392529
40854197
60461804
9411399
58004362
58481572
1948619
46759931
32828466
90590303
83965310
14124483
27780618
50720316
22865492
89943723
8949584
22908820
24094653
55990317
52445243
93229151
94445733
31562086
28373620
3734202
33198837
92107604
68957034
42566602
29942771
7478431
30430448
66712905
87188910
80666118
20618850
34813545
67835547
93454062
8728498
24299458
66011101
56399890
77058170
27572350
15397015
51452091
73642296
43262273
25774450
8233522
43455065
31849106
34267821
25789292
82385847
70902743
3679218
5821478
83326048
23002292
8267140
4031928
11109820
84749824
5758500
15693423
75732622
19613556
12058741
61245672
80845034
90241276
9742784
72270339
27548199
53639887
38511382
48452496
26700718
65741786
89023486
79606837
20665527
50091196
45112597
9674034
2772337
45663742
24081957
27381364
87871125
48815792
98127546
35405755
4280258
73347079
2009101
81983656
85215458
73652465
58820853
81105204
91351370
82633039
16195942
17377751
27243130
95040215
1475311
53582587
75498415
78156759
52687916
86848580
23338098
49003716
14785765
69031175
59839761
76913983
16051717
72930878
53567095
14087893
33236860
73580471
22835145
5318639
11827244
4939890
3717252
84013237
50907446
48130397
23912837
84433295
19261966
97518132
31924889
74822425
20366274
56832966
77505534
69764158
34710660
6417578
59546797
51023024
94977269
40605147
37445295
4904633
30328508
44394295
77723643
48424886
24635329
91703676
66431519
2526224
18621798
65305978
57554544
90681088
85106525
64182380
8264361
57134592
4810828
51296484
49238328
94259633
13048472
73898890
52472775
25801745
80532587
26866984
9839945
98477480
88284599
92919429
71818935
10774240
39114494
86310763
84211599
10804018
......@@ -71,13 +71,13 @@ class Args:
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: str = "checkpoints/agent.pt"
checkpoint: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load"""
embedding_file: Optional[str] = "embeddings_en.npy"
"""the embedding file for card embeddings"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = True
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
......@@ -137,24 +137,29 @@ if __name__ == "__main__":
envs = RecordEpisodeStatistics(envs)
if args.agent:
if args.embedding_file:
embeddings = np.load(args.embedding_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
# count lines of code_list
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = agent.eval()
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
else:
state_dict = None
if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
if state_dict:
agent.load_state_dict(state_dict)
else:
prefix = "_orig_mod."
if state_dict:
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
agent.load_state_dict(state_dict)
if args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
......@@ -220,7 +225,7 @@ if __name__ == "__main__":
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
if episode_reward == -1:
if episode_reward < 0:
win = 0
else:
win = 1
......@@ -230,7 +235,6 @@ if __name__ == "__main__":
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
# print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}")
if len(episode_lengths) >= args.num_episodes:
break
......
import os
import time
from dataclasses import dataclass
from typing import Optional
import numpy as np
......@@ -17,7 +18,7 @@ class Args:
"""the directory of ydk files"""
code_list_file: str = "code_list.txt"
"""the file containing the list of card codes"""
embeddings_file: str = "embeddings.npy"
embeddings_file: Optional[str] = "embeddings.npy"
"""the npz file containing the embeddings of the cards"""
cards_db: str = "../assets/locale/en/cards.cdb"
"""the cards database file"""
......@@ -77,6 +78,7 @@ if __name__ == "__main__":
code_list = [int(code.strip()) for code in code_list]
print(f"The database contains {len(code_list)} cards.")
if embeddings_file is not None:
# read embeddings
if not os.path.exists(embeddings_file):
sample_embedding = get_embeddings(["test"])[0]
......@@ -99,17 +101,15 @@ if __name__ == "__main__":
exit()
new_texts = read_texts(cards_db, new_codes)
print(new_texts)
if embeddings_file is not None:
embeddings = get_embeddings(new_texts, args.batch_size, args.wait_time, verbose=True)
# add new embeddings
all_embeddings = np.concatenate([all_embeddings, np.array(embeddings)], axis=0)
np.save(embeddings_file, all_embeddings)
# update code_list
code_list += new_codes
# save embeddings and code_list
np.save(embeddings_file, all_embeddings)
with open(code_list_file, "w") as f:
f.write("\n".join(map(str, code_list)) + "\n")
......
......@@ -45,7 +45,9 @@ class Encoder(nn.Module):
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
if embedding_shape is None:
n_embed, embed_dim = 1000, 1024
n_embed, embed_dim = 999, 1024
elif isinstance(embedding_shape, int):
n_embed, embed_dim = embedding_shape, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
......@@ -55,12 +57,13 @@ class Encoder(nn.Module):
self.id_norm = nn.LayerNorm(c // 4, elementwise_affine=False)
self.owner_embed = nn.Embedding(2, c // 16 * 2)
self.owner_embed = nn.Embedding(2, c // 16)
self.position_embed = nn.Embedding(9, c // 16 * 2)
self.overley_embed = nn.Embedding(2, c // 16)
self.attribute_embed = nn.Embedding(8, c // 16)
self.race_embed = nn.Embedding(27, c // 16)
self.level_embed = nn.Embedding(14, c // 16)
self.counter_embed = nn.Embedding(16, c // 16)
self.type_fc_emb = linear(25, c // 16 * 2)
self.atk_fc_emb = linear(c_num, c // 16)
self.def_fc_emb = linear(c_num, c // 16)
......@@ -98,8 +101,9 @@ class Encoder(nn.Module):
self.a_yesno_embed = nn.Embedding(3, c // divisor)
self.a_phase_embed = nn.Embedding(4, c // divisor)
self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
self.a_position_embed = nn.Embedding(5, c // divisor)
self.a_option_embed = nn.Embedding(4, c // divisor)
self.a_position_embed = nn.Embedding(9, c // divisor)
self.a_option_embed = nn.Embedding(4, c // divisor // 2)
self.a_number_embed = nn.Embedding(13, c // divisor // 2)
self.a_place_embed = nn.Embedding(31, c // divisor // 2)
self.a_attrib_embed = nn.Embedding(31, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
......@@ -164,9 +168,10 @@ class Encoder(nn.Module):
x_a_cancel = self.a_cancel_finish_embed(x[:, :, 4])
x_a_position = self.a_position_embed(x[:, :, 5])
x_a_option = self.a_option_embed(x[:, :, 6])
x_a_place = self.a_place_embed(x[:, :, 7])
x_a_attrib = self.a_attrib_embed(x[:, :, 8])
return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_place, x_a_attrib
x_a_number = self.a_number_embed(x[:, :, 7])
x_a_place = self.a_place_embed(x[:, :, 8])
x_a_attrib = self.a_attrib_embed(x[:, :, 9])
return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib
def get_action_card_(self, x, f_cards):
b, n, c = x.shape
......@@ -211,7 +216,8 @@ class Encoder(nn.Module):
x_attribute = self.attribute_embed(x1[:, :, 5])
x_race = self.race_embed(x1[:, :, 6])
x_level = self.level_embed(x1[:, :, 7])
return x_owner, x_position, x_overley, x_attribute, x_race, x_level
x_counter = self.counter_embed(x1[:, :, 8])
return x_owner, x_position, x_overley, x_attribute, x_race, x_level, x_counter
def encode_card_feat2(self, x2):
x_atk = self.num_transform(x2[:, :, 0:2])
......@@ -243,8 +249,8 @@ class Encoder(nn.Module):
x_card_ids = x_cards[:, :, :2].long()
x_card_ids = x_card_ids[..., 0] * 256 + x_card_ids[..., 1]
x_cards_1 = x_cards[:, :, 2:10].long()
x_cards_2 = x_cards[:, :, 10:].to(torch.float32)
x_cards_1 = x_cards[:, :, 2:11].long()
x_cards_2 = x_cards[:, :, 11:].to(torch.float32)
x_id = self.encode_card_id(x_card_ids)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 0]))
......
......@@ -7,8 +7,7 @@ from ygoenv.ygopro import init_module
def load_deck(fn):
with open(fn) as f:
lines = f.readlines()
noside = itertools.takewhile(lambda x: "side" not in x, lines)
deck = [int(line) for line in noside if line[:-1].isdigit()]
deck = [int(line) for line in lines if line[:-1].isdigit()]
return deck
......@@ -25,7 +24,7 @@ _languages = {
"chinese": "zh",
}
def init_ygopro(lang, deck, code_list_file, preload_tokens=True):
def init_ygopro(lang, deck, code_list_file, preload_tokens=False):
short = _languages[lang]
db_path = Path(get_root_directory(), 'assets', 'locale', short, 'cards.cdb')
deck_fp = Path(deck)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment