Commit 932b6e39 authored by sbl1996@126.com

Merge yugioh-ai and envpool2

# Xmake cache
.xmake/
# MacOS Cache
.DS_Store
*.out
code_list.txt
*.npy
.vscode/
checkpoints/
runs/
logs/
k8s_job/
script
assets/locale/*/*.cdb
assets/locale/*/strings.conf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
*.o
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
\ No newline at end of file
yugioh-ai
---------
MIT License
Copyright (c) 2024 Hastur
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Modified from tspivey's yugioh-game <https://github.com/tspivey/yugioh-game>
------------------------------------------------------------------------------
MIT License
Copyright (c) 2017 Tyler Spivey
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Yugioh AI
Yugioh AI uses large language models (LLMs) and reinforcement learning (RL) to play Yu-Gi-Oh. It is inspired by [yugioh-ai](https://github.com/melvinzhang/yugioh-ai) and [yugioh-game](https://github.com/tspivey/yugioh-game), and uses [ygopro-core](https://github.com/Fluorohydride/ygopro-core).
## Usage
### Setup
An automated setup script is provided for Linux (tested on Ubuntu 22.04). It installs all dependencies and builds the library. To run it, execute the following command:
```bash
make setup
```
### Running
Test that the repo is set up correctly by running:
```
python cli.py --deck1 deck/Starter.ydk --deck2 deck/BlueEyes.ydk
```
You should see text output showing two AI players dueling by making random moves.
You can set `--seed` to a fixed value to get deterministic results.
## Implementation
The implementation is initially based on [yugioh-game](https://github.com/tspivey/yugioh-game). To provide a clean game environment, it removes all server code and only keeps basic duel-related classes and message handlers.
To implement the AI, every message handler also provides all possible actions that can be taken in response to the message, an approach inspired by [yugioh-ai](https://github.com/melvinzhang/yugioh-ai).
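The idea of a handler that also reports the legal responses can be sketched as follows. This is a minimal illustration, not the repository's code: the `Decision` class, the handler names, and the message payloads are all hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Decision:
    """A pending decision: the message name plus every legal response."""
    msg: str
    options: List[str]


def handle_yesno(data: dict) -> Decision:
    # Hypothetical handler: a yes/no prompt always has exactly two answers.
    return Decision(msg="yesno", options=["y", "n"])


def handle_select_position(data: dict) -> Decision:
    # Hypothetical handler: offer only the positions listed in the payload.
    return Decision(msg="select_position", options=list(data["positions"]))


# Hypothetical registry mapping message names to their handlers.
HANDLERS: Dict[str, Callable[[dict], Decision]] = {
    "yesno": handle_yesno,
    "select_position": handle_select_position,
}


def dispatch(name: str, data: dict) -> Decision:
    """Run the handler for a message and return the decision it exposes."""
    return HANDLERS[name](data)
```

An agent (random or learned) then only needs to pick one entry from `Decision.options`; it never has to know the rules that produced the list.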
## Notes
Never
\ No newline at end of file
#created by ...
#main
89631139
89631139
89631139
38517737
38517737
38517737
45467446
23434538
23434538
23434538
71039903
71039903
71039903
45644898
45644898
79814787
8240199
8240199
8240199
97268402
97268402
2295440
6853254
6853254
6853254
38120068
38120068
38120068
39701395
39701395
41620959
41620959
41620959
48800175
48800175
48800175
43898403
43898403
63356631
63356631
#extra
40908371
40908371
59822133
59822133
50954680
83994433
33698022
39030163
31801517
02978414
63767246
64332231
10443957
41999284
!side
8233522
8233522
14558127
14558127
14558127
56399890
53129443
25789292
25789292
43455065
43455065
43898403
11109820
11109820
11109820
#created by ...
#main
89631139
55410871
89631139
80701178
31036355
38517737
80701178
80701178
95492061
95492061
95492061
53303460
53303460
53303460
14558127
14558127
23434538
55410871
55410871
31036355
31036355
48800175
48800175
48800175
70368879
70368879
70368879
21082832
46052429
46052429
46052429
24224830
24224830
24224830
73915051
10045474
10045474
37576645
37576645
37576645
#extra
31833038
85289965
74997493
5043010
65330383
38342335
2857636
28776350
75452921
3987233
3987233
99111753
98978921
41999284
41999284
!side
75732622
15397015
15397015
73642296
23434538
5821478
77058170
3679218
25774450
43898403
23002292
23002292
84749824
84749824
#created by wxapp_ygo
#main
25451383
25451383
32731036
35984222
60242223
60242223
48048590
62962630
62962630
62962630
68468459
68468459
68468459
45484331
19096726
95515789
45883110
14558127
14558127
14558127
23434538
23434538
23434538
36577931
25311006
35269904
44362883
06498706
06498706
06498706
01984618
01984618
75500286
81439173
82738008
29948294
36637374
24224830
24224830
18973184
01041278
10045474
10045474
10045474
17751597
17751597
#extra
11321089
03410461
72272462
87746184
87746184
41373230
51409648
01906812
24915933
70534340
44146295
44146295
92892239
38811586
53971455
!side
27204311
82385847
34267821
73642296
27572350
93039339
18144506
08267140
08267140
08267140
15693423
15693423
15693423
51452091
23002292
\ No newline at end of file
#created by wxapp_ygo
#main
63941210
63941210
10604644
46659709
46659709
70095154
70095154
70095154
05370235
23434538
23434538
23434538
23893227
23893227
23893227
01142880
01142880
56364287
56364287
56364287
14532163
14532163
86686671
60600126
60600126
60600126
18144506
12580477
63995093
63995093
03659803
37630732
39973386
84797028
84797028
84797028
64753988
10045474
10045474
10045474
#extra
01546123
87116928
74157028
79229522
79229522
84058253
84058253
22850702
90448279
10443957
73964868
58069384
70369116
46724542
60303245
!side
\ No newline at end of file
#created by ...
#main
27204311
10000080
10000080
10000080
95440946
95440946
95440946
14558127
14558127
14558127
23434538
23434538
23434538
68829754
68829754
94224458
31434645
31434645
31434645
10045474
10045474
20612097
20612097
20612097
58921041
58921041
53334471
53334471
90846359
90846359
23516703
23516703
82732705
67007102
93191801
93191801
20590515
20590515
20590515
56984514
#extra
74889525
62541668
90448279
26556950
26096328
49032236
56910167
56910167
56910167
73082255
37129797
03814632
72860663
!side
\ No newline at end of file
#created by ...
#main
6631034
6631034
6631034
43096270
43096270
43096270
69247929
69247929
69247929
77542832
77542832
77542832
11091375
11091375
11091375
35052053
35052053
35052053
49881766
83104731
83104731
30190809
30190809
26412047
26412047
26412047
43422537
43422537
43422537
53129443
66788016
66788016
66788016
72302403
72302403
44095762
44095762
44095762
70342110
70342110
#extra
!side
#created by ...
#main
55623480
52467217
52467217
52467217
92826944
92826944
92826944
41562624
41562624
99423156
99423156
94801854
94801854
94801854
49959355
49959355
49959355
79783880
14558127
14558127
14558127
36630403
36630403
23434538
23434538
23434538
97268402
12580477
18144507
75500286
81439173
13965201
13965201
24224830
24224830
40364916
40364916
4333086
4333086
10045474
10045474
40605147
40605147
41420027
#extra
59843383
27548199
50954680
83283063
74586817
52711246
57288064
26326541
98558751
86066372
72860663
86926989
37129797
91420202
41999284
!side
#created by wxapp_ygo
#main
27204311
27204311
26077387
26077387
26077387
37351133
14558127
14558127
14558127
23434538
23434538
23434538
20357457
20357457
94145021
94145021
73594093
35726888
35726888
99550630
35261759
70368879
70368879
25311006
32807846
63166095
63166095
63166095
24224830
24224830
52340444
98338152
98338152
98338152
51227866
09726840
09726840
24010609
24010609
50005218
#extra
86066372
75147529
29301450
98462037
63013339
63288573
63288573
63288573
90673288
90673288
90673288
08491308
08491308
08491308
12421694
!side
19613556
12580477
12580477
04031928
04031928
18144506
35269904
08267140
08267140
31849106
31849106
83326048
83326048
15693423
15693423
#created by wxapp_ygo
#main
62849088
05141117
93490856
93490856
93490856
55273560
55273560
55273560
20001443
20001443
20001443
56495147
56495147
56495147
59438930
59438930
14558127
14558127
23434538
23434538
83764718
93850690
56465981
56465981
56465981
18144506
83308376
24224830
24224830
65681983
39730727
39730727
10045474
10045474
23068051
23068051
78836195
14821890
14821890
14821890
99137266
#extra
53971455
86682165
84815190
96633955
47710198
09464441
69248256
69248256
93039339
05402805
78917791
29301450
!side
\ No newline at end of file
#created by ...
#main
27204311
06728559
87052196
87052196
23431858
93490856
93490856
93490856
56495147
56495147
56495147
20001443
20001443
20001443
55273560
55273560
55273560
14558127
14558127
14558127
23434538
23434538
23434538
97268402
97268402
97268402
98159737
35261759
35261759
56465981
56465981
56465981
93850690
24224830
24224830
10045474
10045474
10045474
14821890
14821890
#extra
42632209
60465049
96633955
84815190
47710198
9464441
5041348
69248256
69248256
83755611
43202238
!side
#main
176393
645088
904186
1426715
1799465
2625940
2819436
3285552
4417408
7392746
7610395
8025951
8198621
9047461
9396663
9925983
9929399
10389143
11050416
11050417
11050418
11050419
11654068
11738490
12958920
12965762
13536607
13764603
13935002
14089429
14470846
14470847
14821891
14957441
15341822
15341823
15394084
15590356
15629802
16943771
16946850
17000166
17228909
17418745
18027139
18027140
18027141
18494512
19280590
20001444
20368764
21179144
21770261
21830680
22110648
22404676
22411610
22493812
22953212
23116809
23331401
23837055
24874631
25415053
25419324
26326542
27198002
27204312
27450401
27882994
28053764
28062326
28355719
28674153
29491335
29843092
29843093
29843094
30069399
30327675
30327676
30650148
30765616
30811117
31480216
31533705
31600514
31986289
32056071
32335698
32446631
33676147
34479659
34690954
34767866
34822851
35263181
35268888
35514097
35834120
36629636
38030233
38041941
38053382
39972130
40551411
40633085
40703223
40844553
41329459
41456842
42427231
42671152
42956964
43140792
43664495
44026394
44052075
44092305
44097051
44308318
44330099
44586427
44689689
46173680
46173681
46647145
47658965
48068379
48115278
48411997
49752796
49808197
51208047
51611042
51987572
52340445
52900001
53855410
53855411
54537490
55326323
56051649
56495148
56597273
58371672
59160189
59900656
60025884
60406592
60514626
60764582
60764583
62125439
62481204
62543394
63184228
63442605
64213018
64382840
64583601
65500516
65810490
66200211
66661679
67284108
67489920
67922703
67949764
68815402
69550260
69811711
69868556
69890968
70391589
70465811
70875956
70950699
71645243
72291079
73915052
73915053
73915054
73915055
74440056
74627017
74659583
74983882
75119041
75524093
75622825
75732623
76524507
76589547
77672445
78394033
78789357
78836196
79387393
81767889
82255873
82324106
82340057
82556059
82994510
83239740
84816245
85243785
85771020
85771021
85969518
86801872
86871615
87240372
87669905
88923964
89907228
90884404
91512836
93104633
93130022
93224849
93490857
93912846
94703022
94973029
97452818
98596597
98875864
99092625
99137267
#extra
!side
#created by ...
#main
75498415
81105204
81105204
81105204
58820853
58820853
58820853
49003716
49003716
49003716
14785765
85215458
85215458
85215458
2009101
2009101
2009101
22835145
22835145
22835145
73652465
1475311
1475311
53129443
5318639
5318639
14087893
27243130
27243130
91351370
91351370
91351370
53567095
53567095
53567095
53582587
59839761
72930878
72930878
84749824
#extra
52687916
33236860
16051717
23338098
81983656
69031175
73580471
95040215
76913983
17377751
16195942
86848580
82633039
73347079
78156759
!side
#created by wxapp_ygo
#main
18094166
18094166
18094166
22865492
22865492
27780618
27780618
40044918
40044918
09411399
09411399
89943723
16605586
50720316
59392529
14124483
83965310
23434538
23434538
23434538
14558127
14558127
14558127
94145021
94145021
08949584
08949584
08949584
21143940
21143940
21143940
45906428
52947044
24094653
24094653
24224830
24224830
81439173
32807846
75047173
#extra
30757127
89870349
58481572
58481572
22908820
93347961
46759931
40854197
60461804
56733747
32828466
90590303
58004362
19324993
01948619
!side
94145021
34267821
34267821
35269904
35269904
35269904
14532163
14532163
05758500
18144506
08267140
08267140
08267140
43262273
23002292
\ No newline at end of file
#created by ...
#main
14513273
14513273
14513273
76794549
22211622
96227613
96227613
14920218
69610326
69610326
40318957
40318957
40318957
72714461
72714461
72714461
49684352
49684352
49684352
73941492
27204311
14558127
14558127
14558127
23434538
23434538
23434538
01845204
25311006
41620959
41620959
81439173
24224830
24224830
65681983
82190203
82190203
55795155
74850403
10045474
01344018
#extra
43387895
43387895
53262004
76815942
58074177
84815190
30095833
16691074
20665527
04280258
02772337
92812851
45819647
24094258
22125101
!side
27204311
94145021
94145021
94145021
14532163
25311006
25311006
08267140
08267140
08267140
10045474
10045474
15693423
15693423
15693423
#created by ...
#main
3717252
3717252
3717252
77723643
77723643
30328508
30328508
59546797
97518132
34710660
51023024
51023024
4939890
4939890
4939890
59438930
59438930
69764158
24635329
37445295
37445295
23434538
23434538
1475311
11827244
44394295
44394295
44394295
53129443
81439173
6417578
6417578
48130397
23912837
23912837
77505534
77505534
77505534
4904633
40605147
40605147
84749824
#extra
84433295
74822425
74822425
19261966
20366274
48424886
50907446
50907446
94977269
52687916
73580471
31924889
56832966
84013237
82633039
!side
#created by ...
#main
55623480
52467217
52467217
52467217
92826944
92826944
92826944
41562624
41562624
99423156
99423156
94801854
94801854
94801854
49959355
49959355
49959355
79783880
14558127
14558127
14558127
36630403
36630403
23434538
23434538
23434538
97268402
12580477
18144506
75500286
81439173
13965201
13965201
24224830
24224830
40364916
40364916
4333086
4333086
10045474
10045474
40605147
40605147
41420027
#extra
59843383
27548199
50954680
83283063
74586817
52711246
57288064
26326541
98558751
86066372
72860663
86926989
37129797
91420202
41999284
!side
# Dimensions
B: batch size
C: number of channels
H: number of history actions
# Features
f_cards: (B, n_cards, C), features of cards
f_global: (B, C), global features
f_h_actions: (B, H, C), features of history actions
f_actions: (B, max_actions, C), features of current legal actions
output: (B, max_actions, 1), value of each action
# Fusion
## Method 1
```
f_cards -> n encoder layers -> f_cards
f_global -> ResMLP -> f_global
f_cards = f_cards + f_global
f_actions -> n encoder layers -> f_actions
f_cards[id] -> f_a_cards -> ResMLP -> f_a_cards
f_actions = f_a_cards + f_a_feats
f_actions, f_cards -> n decoder layers -> f_actions
f_h_actions -> n encoder layers -> f_h_actions
f_actions, f_h_actions -> n decoder layers -> f_actions
f_actions -> MLP -> output
```
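The two element-wise steps of Method 1 (`f_cards = f_cards + f_global` and the `f_cards[id]` lookup) can be illustrated for a single batch element with plain lists standing in for tensors. This is only a sketch: the encoder, decoder, and ResMLP layers are left out, and it assumes that `f_a_feats` denotes the encoded action features and that `id` is one card index per action.

```python
from typing import List

Vector = List[float]


def add_global(f_cards: List[Vector], f_global: Vector) -> List[Vector]:
    """Add the global feature vector to every card vector."""
    return [[c + g for c, g in zip(card, f_global)] for card in f_cards]


def gather_cards(f_cards: List[Vector], card_ids: List[int]) -> List[Vector]:
    """Pick the card vector that each action refers to (f_cards[id])."""
    return [f_cards[i] for i in card_ids]


def add_features(a: List[Vector], b: List[Vector]) -> List[Vector]:
    """Element-wise sum of two equally shaped feature lists."""
    return [[x + y for x, y in zip(u, v)] for u, v in zip(a, b)]


# Toy example: 3 cards, 2 actions, C = 2 channels.
f_cards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
f_global = [0.5, -0.5]
f_a_feats = [[10.0, 10.0], [20.0, 20.0]]
action_card_ids = [2, 0]  # action 0 targets card 2, action 1 targets card 0

f_cards = add_global(f_cards, f_global)
f_a_cards = gather_cards(f_cards, action_card_ids)
f_actions = add_features(f_a_cards, f_a_feats)
```

With real tensors of shape `(B, n_cards, C)` and `(B, C)`, the first step corresponds to broadcasting the global features over the card axis.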
\ No newline at end of file
# Features
## Card (39)
- id: 2, uint16 -> 2 uint8, name+desc
- location: 1, int, 0: N/A, 1+: same as location2str (9)
- seq: 1, int, 0: N/A, 1+: seq in location
- owner: 1, int, 0: me, 1: oppo (2)
- position: 1, int, 0: N/A, same as position2str
- overlay: 1, int, 0: not, 1: xyz material
- attribute: 1, int, 0: N/A, same as attribute2str[2:]
- race: 1, int, 0: N/A, same as race2str
- level: 1, int, 0: N/A
- atk: 2, max 65535 to 2 bytes
- def: 2, max 65535 to 2 bytes
- type: 25, multi-hot, same as type2str
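Fields such as `id`, `atk`, and `def` store a 16-bit value in two `uint8` slots. A minimal sketch of that packing is shown below; the byte order (high byte first) and the function names are assumptions, since the notes above do not specify them.

```python
from typing import Tuple


def split_u16(value: int) -> Tuple[int, int]:
    """Split a value in [0, 65535] into (high byte, low byte)."""
    if not 0 <= value <= 0xFFFF:
        raise ValueError(f"value out of uint16 range: {value}")
    return value >> 8, value & 0xFF


def join_u16(high: int, low: int) -> int:
    """Inverse of split_u16: rebuild the 16-bit value from its two bytes."""
    return (high << 8) | low
```

For example, an ATK of 3000 becomes the byte pair `(11, 184)` under this assumed ordering, and `join_u16(11, 184)` recovers 3000.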
## Global
- lp: 2, max 65535 to 2 bytes
- oppo_lp: 2, max 65535 to 2 bytes
<!-- - turn: 8, int, trunc to 8 -->
- phase: 1, int, one-hot (10)
- is_first: 1, int, 0: False, 1: True
- is_my_turn: 1, int, 0: False, 1: True
- is_end: 1, int, 0: False, 1: True
## Legal Actions (max 8)
- spec index: 8, int, select target
- msg: 1, int (16)
- act: 1, int (11)
- N/A
- t: Set
- r: Reposition
- c: Special Summon
- s: Summon Face-up Attack
- m: Summon Face-down Defense
- a: Attack
- v: Activate
- v2: Activate the second effect
- v3: Activate the third effect
- v4: Activate the fourth effect
- yes/no: 1, int (3)
- N/A
- Yes
- No
- phase: 1, int (4)
- N/A
- Battle (b)
- Main Phase 2 (m)
- End Phase (e)
- cancel_finish: 1, int (3)
- N/A
- Cancel
- Finish
- position: 1, int, 0: N/A, same as position2str
- option: 1, int, 0: N/A
- place: 1, int (31), 0: N/A,
- 1-7: m
- 8-15: s
- 16-22: om
- 23-30: os
- attribute: 1, int, 0: N/A, same as attribute2id
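The `place` ranges listed above can be decoded back into a zone label and a slot number. The sketch below uses the labels exactly as written (`m`, `s`, `om`, `os`); the function name and the 1-based slot index within each zone are assumptions made for illustration.

```python
from typing import Optional, Tuple

# (zone label, first code, last code), taken from the ranges listed above.
PLACE_RANGES = [("m", 1, 7), ("s", 8, 15), ("om", 16, 22), ("os", 23, 30)]


def decode_place(place: int) -> Optional[Tuple[str, int]]:
    """Map a place code to (zone, slot); 0 means N/A and returns None."""
    if place == 0:
        return None
    for zone, first, last in PLACE_RANGES:
        if first <= place <= last:
            # Slot is counted from 1 within the zone (an assumption).
            return zone, place - first + 1
    raise ValueError(f"place code out of range: {place}")
```

The four ranges plus the N/A value give the 31 distinct codes noted for this field.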
## History Actions
- id: 2x4, uint16 -> 2 uint8, name+desc
- same as Legal Actions
# Deck
## Unsupported
- Many (Crossout Designator)
- Blackwing (add_counter)
- Magician (pendulum)
- Shaddoll (add_counter)
- Shiranui (Fairy Tail - Snow)
- Hero (random_selected)
# Message
## random_selected
Not supported
## add_counter
Not supported
## select_card
- `min` and `max` <= 5 are supported
- `min` > 5 throws an error
- `max` > 5 is truncated to 5
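The three rules above amount to a small bounds check. A minimal sketch, assuming a hypothetical helper name and a `ValueError` as the error type (the notes only say that an error is thrown):

```python
from typing import Tuple

MAX_SELECT = 5  # largest selection size supported by select_card


def clamp_select_bounds(min_count: int, max_count: int) -> Tuple[int, int]:
    """Apply the select_card limits: reject min > 5, truncate max to 5."""
    if min_count > MAX_SELECT:
        raise ValueError(f"select_card with min={min_count} is not supported")
    return min_count, min(max_count, MAX_SELECT)
```

Under these rules a request with `min=2, max=7` is served as `min=2, max=5`, while `min=max=6` is rejected.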
### Unsupported
- Fairy Tail - Snow (min=max=7)
- Pot of Prosperity (min=max=6)
## announce_card
Not supported:
- Alsei, the Sylvan High Protector
- Crossout Designator
## announce_attrib
Only 1 attribute is announced at a time.
Not supported:
- DNA Checkup
# Summon
## Tribute Summon
Through `select_tribute` (multi-select)
## Link Summon
Through `select_unselect_card` (1 card at a time)
## Synchro Summon
- `select_card` to choose the tuner (usually 1 card)
- `select_sum` to choose the non-tuner (1 card at a time)
package("ygopro-core")
    set_homepage("https://github.com/Fluorohydride/ygopro-core")
    set_urls("https://github.com/Fluorohydride/ygopro-core.git")
    add_deps("lua")
    on_install("linux", function (package)
        io.writefile("xmake.lua", [[
            add_rules("mode.debug", "mode.release")
            add_requires("lua")
            target("ygopro-core")
                set_kind("static")
                add_files("*.cpp")
                add_headerfiles("*.h")
                add_packages("lua")
        ]])
        local check_and_insert = function(file, line, insert)
            local lines = table.to_array(io.lines(file))
            if lines[line] ~= insert then
                table.insert(lines, line, insert)
                io.writefile(file, table.concat(lines, "\n"))
            end
        end
        check_and_insert("field.h", 14, "#include <cstring>")
        check_and_insert("interpreter.h", 11, "extern \"C\" {")
        check_and_insert("interpreter.h", 15, "}")
        local configs = {}
        if package:config("shared") then
            configs.kind = "shared"
        end
        import("package.tools.xmake").install(package)
        os.cp("*.h", package:installdir("include", "ygopro-core"))
    end)
package_end()
\ No newline at end of file
import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
import ygoenv
import numpy as np
import optree
import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import Agent
@dataclass
class Args:
    seed: int = 1
    """the random seed"""
    torch_deterministic: bool = True
    """if toggled, sets `torch.backends.cudnn.deterministic = True`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    env_id: str = "YGOPro-v0"
    """the id of the environment"""
    deck: str = "../assets/deck/OldSchool.ydk"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    lang: str = "english"
    """the language to use"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 8
    """the number of history actions to use"""
    player: int = -1
    """the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
    play: bool = False
    """whether to play the game"""
    selfplay: bool = False
    """whether to use selfplay"""
    num_episodes: int = 1024
    """the number of episodes to run"""
    num_envs: int = 64
    """the number of parallel game environments"""
    verbose: bool = False
    """whether to print debug information"""
    bot_type: Literal["random", "greedy"] = "greedy"
    """the type of bot to use"""
    strategy: Literal["random", "greedy"] = "greedy"
    """the strategy to use if agent is not used"""
    agent: bool = False
    """whether to use the agent"""
    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    checkpoint: str = "checkpoints/agent.pt"
    """the checkpoint to load"""
    embedding_file: str = "embeddings_en.npy"
    """the embedding file for card embeddings"""
    compile: bool = False
    """if toggled, the model will be compiled"""
    torch_threads: Optional[int] = None
    """the number of threads to use for torch, defaults to $OMP_NUM_THREADS or 4"""
    env_threads: Optional[int] = 16
    """the number of threads to use for envpool, capped at `num_envs`; if unset, `num_envs` is used"""
if __name__ == "__main__":
    args = tyro.cli(Args)
    if args.play:
        args.num_envs = 1
        args.verbose = True
    args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
    args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
    deck = init_ygopro(args.lang, args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    if args.agent:
        import torch
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = args.torch_deterministic
        torch.set_num_threads(args.torch_threads)
        torch.set_float32_matmul_precision('high')
        device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    num_envs = args.num_envs
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=args.env_threads,
        seed=seed,
        deck1=args.deck1,
        deck2=args.deck2,
        player=args.player,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
        verbose=args.verbose,
    )
    envs.num_envs = num_envs
    envs = RecordEpisodeStatistics(envs)
    if args.agent:
        embeddings = np.load(args.embedding_file)
        L = args.num_layers
        agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
        agent = agent.eval()
        state_dict = torch.load(args.checkpoint, map_location=device)
        if args.compile:
            agent = torch.compile(agent, mode='reduce-overhead')
            agent.load_state_dict(state_dict)
        else:
            prefix = "_orig_mod."
            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
            agent.load_state_dict(state_dict)
            obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), envs.reset()[0])
            with torch.no_grad():
                traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
            agent = torch.jit.optimize_for_inference(traced_model)
    obs, infos = envs.reset()
    episode_rewards = []
    episode_lengths = []
    win_rates = []
    win_reasons = []
    step = 0
    start = time.time()
    start_step = step
    model_time = env_time = 0
    while True:
        if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
            start = time.time()
            start_step = step
            model_time = env_time = 0
        if args.agent:
            _start = time.time()
            obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
            with torch.no_grad():
                values = agent(obs)[0]
            actions = torch.argmax(values, dim=1).cpu().numpy()
            model_time += time.time() - _start
        else:
            if args.strategy == "random":
                actions = np.random.randint(infos['num_options'])
            else:
                actions = np.zeros(num_envs, dtype=np.int32)
        # for k, v in obs.items():
        #     v = v[0]
        #     if k == 'cards_':
        #         v = np.concatenate([np.arange(v.shape[0])[:, None], v], axis=1)
        #     print(k, v.tolist())
        # print(infos)
        # print(actions[0])
        _start = time.time()
        obs, rewards, dones, infos = envs.step(actions)
        env_time += time.time() - _start
        step += 1
        for idx, d in enumerate(dones):
            if d:
                win_reason = infos['win_reason'][idx]
                episode_length = infos['l'][idx]
                episode_reward = infos['r'][idx]
                if args.selfplay:
                    pl = 1 if infos['to_play'][idx] == 0 else -1
                    winner = 0 if episode_reward * pl > 0 else 1
                    win = 1 - winner
                else:
                    if episode_reward == -1:
                        win = 0
                    else:
                        win = 1
                episode_lengths.append(episode_length)
                episode_rewards.append(episode_reward)
                win_rates.append(win)
                win_reasons.append(1 if win_reason == 1 else 0)
                sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
                # print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}")
        if len(episode_lengths) >= args.num_episodes:
            break
    print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
    if not args.play:
        total_time = time.time() - start
        total_steps = (step - start_step) * num_envs
        print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
        print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
import os
import time
from dataclasses import dataclass
import numpy as np
import voyageai
import tyro
from ygoai.embed import read_cards
from ygoai.utils import load_deck
@dataclass
class Args:
    deck_dir: str = "../assets/deck"
    """the directory of ydk files"""
    code_list_file: str = "code_list.txt"
    """the file containing the list of card codes"""
    embeddings_file: str = "embeddings.npy"
    """the npy file containing the embeddings of the cards"""
    cards_db: str = "../assets/locale/en/cards.cdb"
    """the cards database file"""
    batch_size: int = 64
    """the batch size for embedding generation"""
    wait_time: float = 0.1
    """the time to wait between each batch"""
def get_embeddings(texts, batch_size=64, wait_time=0.1, verbose=False):
    vo = voyageai.Client()
    embeddings = []
    for i in range(0, len(texts), batch_size):
        if verbose:
            print(f"Embedding {i} / {len(texts)}")
        embeddings += vo.embed(
            texts[i : i + batch_size], model="voyage-2", truncation=False).embeddings
        time.sleep(wait_time)
    embeddings = np.array(embeddings, dtype=np.float32)
    return embeddings
def read_decks(d):
    # iterate over ydk files
    codes = []
    for file in os.listdir(d):
        if file.endswith(".ydk"):
            file = os.path.join(d, file)
            codes += load_deck(file)
    return set(codes)
def read_texts(cards_db, codes):
    df, cards = read_cards(cards_db)
    code2card = {c.code: c for c in cards}
    texts = []
    for code in codes:
        texts.append(code2card[code].format())
    return texts
if __name__ == "__main__":
    args = tyro.cli(Args)
    deck_dir = args.deck_dir
    code_list_file = args.code_list_file
    embeddings_file = args.embeddings_file
    cards_db = args.cards_db
    # read code_list file, creating an empty one on first run
    if not os.path.exists(code_list_file):
        with open(code_list_file, "w") as f:
            f.write("")
    with open(code_list_file, "r") as f:
        code_list = f.readlines()
    code_list = [int(code.strip()) for code in code_list]
    print(f"The database contains {len(code_list)} cards.")
    # read embeddings
    if not os.path.exists(embeddings_file):
        # embed a dummy text once to discover the embedding dimension
        sample_embedding = get_embeddings(["test"])[0]
        all_embeddings = np.zeros((0, len(sample_embedding)), dtype=np.float32)
    else:
        all_embeddings = np.load(embeddings_file)
    print("Embedding dim:", all_embeddings.shape[1])
    assert len(all_embeddings) == len(code_list), f"The number of embeddings({len(all_embeddings)}) does not match the number of cards."
    all_codes = set(code_list)
    new_codes = []
    for code in read_decks(deck_dir):
        if code not in all_codes:
            new_codes.append(code)
    if not new_codes:
        print("No new cards have been added to the database.")
        exit()
    new_texts = read_texts(cards_db, new_codes)
    embeddings = get_embeddings(new_texts, args.batch_size, args.wait_time, verbose=True)
    # add new embeddings
    all_embeddings = np.concatenate([all_embeddings, np.array(embeddings)], axis=0)
    # update code_list
    code_list += new_codes
    # save embeddings and code_list
    np.save(embeddings_file, all_embeddings)
    with open(code_list_file, "w") as f:
        f.write("\n".join(map(str, code_list)) + "\n")
    print(f"{len(new_codes)} new cards have been added to the database.")
\ No newline at end of file
import io
import os
from setuptools import find_packages, setup
NAME = 'ygoai'
IMPORT_NAME = 'ygoai'
DESCRIPTION = "A Yu-Gi-Oh! AI."
URL = 'https://github.com/sbl1996/ygo-agent'
EMAIL = 'sbl1996@gmail.com'
AUTHOR = 'Hastur'
REQUIRES_PYTHON = '>=3.8.0'
VERSION = None
REQUIRED = []
here = os.path.dirname(os.path.abspath(__file__))
try:
    with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION
about = {}
if not VERSION:
    with open(os.path.join(here, IMPORT_NAME, '_version.py')) as f:
        exec(f.read(), about)
else:
    about['__version__'] = VERSION
setup(
    name=NAME,
    version=about['__version__'],
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    author_email=EMAIL,
    python_requires=REQUIRES_PYTHON,
    url=URL,
    # `include` expects a sequence of patterns, not a bare string
    packages=find_packages(include=['ygoai*']),
    install_requires=REQUIRED,
    dependency_links=[],
    license='MIT',
)
\ No newline at end of file
add_rules("mode.debug", "mode.release")
add_repositories("my-repo repo")
add_requires(
"ygopro-core", "pybind11 2.10.*", "fmt 10.2.*", "glog 0.6.0",
"concurrentqueue 1.0.4", "sqlitecpp 3.2.1", "unordered_dense 4.4.*")
target("dummy_ygopro")
add_rules("python.library")
add_files("ygoenv/ygoenv/dummy/*.cpp")
add_packages("pybind11", "fmt", "glog", "concurrentqueue")
set_languages("c++17")
set_policy("build.optimization.lto", true)
add_includedirs("ygoenv")
after_build(function (target)
local install_target = "$(projectdir)/ygoenv/ygoenv/dummy"
os.mv(target:targetfile(), install_target)
print("move target to " .. install_target)
end)
target("ygopro_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/ygopro/*.cpp")
add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "ygopro-core")
set_languages("c++17")
add_cxxflags("-flto=auto -fno-fat-lto-objects -fvisibility=hidden -march=native")
add_includedirs("ygoenv")
-- for _, header in ipairs(os.files("ygoenv/ygoenv/core/*.h")) do
-- set_pcxxheader(header)
-- end
after_build(function (target)
local install_target = "$(projectdir)/ygoenv/ygoenv/ygopro"
os.mv(target:targetfile(), install_target)
print("Move target to " .. install_target)
end)
__version__ = "0.0.1"
\ No newline at end of file
from enum import Flag, auto, unique, IntFlag
__ = lambda s: s
AMOUNT_ATTRIBUTES = 7
AMOUNT_RACES = 25
ATTRIBUTES_OFFSET = 1010
LINK_MARKERS = {
0x001: __("bottom left"),
0x002: __("bottom"),
0x004: __("bottom right"),
0x008: __("left"),
0x020: __("right"),
0x040: __("top left"),
0x080: __("top"),
0x100: __("top right")
}
@unique
class LOCATION(IntFlag):
DECK = 0x1
HAND = 0x2
MZONE = 0x4
SZONE = 0x8
GRAVE = 0x10
REMOVED = 0x20
EXTRA = 0x40
OVERLAY = 0x80
ONFIELD = MZONE | SZONE
FZONE = 0x100
PZONE = 0x200
DECKBOT = 0x10001 # Return to deck bottom
DECKSHF = 0x20001 # Return to deck and shuffle
location2str = {
LOCATION.DECK: 'Deck',
LOCATION.HAND: 'Hand',
LOCATION.MZONE: 'Main Monster Zone',
LOCATION.SZONE: 'Spell & Trap Zone',
LOCATION.GRAVE: 'Graveyard',
LOCATION.REMOVED: 'Banished',
LOCATION.EXTRA: 'Extra Deck',
LOCATION.FZONE: 'Field Zone',
}
all_locations = list(location2str.keys())
PHASES = {
0x01: __('draw phase'),
0x02: __('standby phase'),
0x04: __('main1 phase'),
0x08: __('battle start phase'),
0x10: __('battle step phase'),
0x20: __('damage phase'),
0x40: __('damage calculation phase'),
0x80: __('battle phase'),
0x100: __('main2 phase'),
0x200: __('end phase'),
}
@unique
class POSITION(IntFlag):
FACEUP_ATTACK = 0x1
FACEDOWN_ATTACK = 0x2
FACEUP_DEFENSE = 0x4
FACEUP = FACEUP_ATTACK | FACEUP_DEFENSE
FACEDOWN_DEFENSE = 0x8
FACEDOWN = FACEDOWN_ATTACK | FACEDOWN_DEFENSE
ATTACK = FACEUP_ATTACK | FACEDOWN_ATTACK
DEFENSE = FACEUP_DEFENSE | FACEDOWN_DEFENSE
position2str = {
POSITION.FACEUP_ATTACK: "Face-up Attack",
POSITION.FACEDOWN_ATTACK: "Face-down Attack",
POSITION.FACEUP_DEFENSE: "Face-up Defense",
POSITION.FACEUP: "Face-up",
POSITION.FACEDOWN_DEFENSE: "Face-down Defense",
POSITION.FACEDOWN: "Face-down",
POSITION.ATTACK: "Attack",
POSITION.DEFENSE: "Defense",
}
all_positions = list(position2str.keys())
RACES_OFFSET = 1020
@unique
class QUERY(IntFlag):
CODE = 0x1
POSITION = 0x2
ALIAS = 0x4
TYPE = 0x8
LEVEL = 0x10
RANK = 0x20
ATTRIBUTE = 0x40
RACE = 0x80
ATTACK = 0x100
DEFENSE = 0x200
BASE_ATTACK = 0x400
BASE_DEFENSE = 0x800
REASON = 0x1000
REASON_CARD = 0x2000
EQUIP_CARD = 0x4000
TARGET_CARD = 0x8000
OVERLAY_CARD = 0x10000
COUNTERS = 0x20000
OWNER = 0x40000
STATUS = 0x80000
LSCALE = 0x200000
RSCALE = 0x400000
LINK = 0x800000
@unique
class TYPE(IntFlag):
MONSTER = 0x1
SPELL = 0x2
TRAP = 0x4
NORMAL = 0x10
EFFECT = 0x20
FUSION = 0x40
RITUAL = 0x80
TRAPMONSTER = 0x100
SPIRIT = 0x200
UNION = 0x400
DUAL = 0x800
TUNER = 0x1000
SYNCHRO = 0x2000
TOKEN = 0x4000
QUICKPLAY = 0x10000
CONTINUOUS = 0x20000
EQUIP = 0x40000
FIELD = 0x80000
COUNTER = 0x100000
FLIP = 0x200000
TOON = 0x400000
XYZ = 0x800000
PENDULUM = 0x1000000
SPSUMMON = 0x2000000
LINK = 0x4000000
# for this mud only
EXTRA = XYZ | SYNCHRO | FUSION | LINK
type2str = {
TYPE.MONSTER: "Monster",
TYPE.SPELL: "Spell",
TYPE.TRAP: "Trap",
TYPE.NORMAL: "Normal",
TYPE.EFFECT: "Effect",
TYPE.FUSION: "Fusion",
TYPE.RITUAL: "Ritual",
TYPE.TRAPMONSTER: "Trap Monster",
TYPE.SPIRIT: "Spirit",
TYPE.UNION: "Union",
TYPE.DUAL: "Dual",
TYPE.TUNER: "Tuner",
TYPE.SYNCHRO: "Synchro",
TYPE.TOKEN: "Token",
TYPE.QUICKPLAY: "Quick-play",
TYPE.CONTINUOUS: "Continuous",
TYPE.EQUIP: "Equip",
TYPE.FIELD: "Field",
TYPE.COUNTER: "Counter",
TYPE.FLIP: "Flip",
TYPE.TOON: "Toon",
TYPE.XYZ: "XYZ",
TYPE.PENDULUM: "Pendulum",
TYPE.SPSUMMON: "Special",
TYPE.LINK: "Link"
}
all_types = list(type2str.keys())
@unique
class ATTRIBUTE(IntFlag):
ALL = 0x7f
NONE = 0x0 # Token
EARTH = 0x01
WATER = 0x02
FIRE = 0x04
WIND = 0x08
LIGHT = 0x10
DARK = 0x20
DEVINE = 0x40
attribute2str = {
ATTRIBUTE.ALL: 'All',
ATTRIBUTE.NONE: 'None',
ATTRIBUTE.EARTH: 'Earth',
ATTRIBUTE.WATER: 'Water',
ATTRIBUTE.FIRE: 'Fire',
ATTRIBUTE.WIND: 'Wind',
ATTRIBUTE.LIGHT: 'Light',
ATTRIBUTE.DARK: 'Dark',
ATTRIBUTE.DEVINE: 'Divine'
}
all_attributes = list(attribute2str.keys())
@unique
class RACE(IntFlag):
ALL = 0x3ffffff
NONE = 0x0 # Token
WARRIOR = 0x1
SPELLCASTER = 0x2
FAIRY = 0x4
FIEND = 0x8
ZOMBIE = 0x10
MACHINE = 0x20
AQUA = 0x40
PYRO = 0x80
ROCK = 0x100
WINDBEAST = 0x200
PLANT = 0x400
INSECT = 0x800
THUNDER = 0x1000
DRAGON = 0x2000
BEAST = 0x4000
BEASTWARRIOR = 0x8000
DINOSAUR = 0x10000
FISH = 0x20000
SEASERPENT = 0x40000
REPTILE = 0x80000
PSYCHO = 0x100000
DEVINE = 0x200000
CREATORGOD = 0x400000
WYRM = 0x800000
CYBERSE = 0x1000000
ILLUSION = 0x2000000
race2str = {
RACE.NONE: "None",
RACE.WARRIOR: 'Warrior',
RACE.SPELLCASTER: 'Spellcaster',
RACE.FAIRY: 'Fairy',
RACE.FIEND: 'Fiend',
RACE.ZOMBIE: 'Zombie',
RACE.MACHINE: 'Machine',
RACE.AQUA: 'Aqua',
RACE.PYRO: 'Pyro',
RACE.ROCK: 'Rock',
RACE.WINDBEAST: 'Windbeast',
RACE.PLANT: 'Plant',
RACE.INSECT: 'Insect',
RACE.THUNDER: 'Thunder',
RACE.DRAGON: 'Dragon',
RACE.BEAST: 'Beast',
RACE.BEASTWARRIOR: 'Beast Warrior',
RACE.DINOSAUR: 'Dinosaur',
RACE.FISH: 'Fish',
RACE.SEASERPENT: 'Sea Serpent',
RACE.REPTILE: 'Reptile',
RACE.PSYCHO: 'Psycho',
RACE.DEVINE: 'Divine',
RACE.CREATORGOD: 'Creator God',
RACE.WYRM: 'Wyrm',
RACE.CYBERSE: 'Cyberse',
RACE.ILLUSION: 'Illusion'
}
all_races = list(race2str.keys())
@unique
class REASON(IntFlag):
DESTROY = 0x1
RELEASE = 0x2
TEMPORARY = 0x4
MATERIAL = 0x8
SUMMON = 0x10
BATTLE = 0x20
EFFECT = 0x40
COST = 0x80
ADJUST = 0x100
LOST_TARGET = 0x200
RULE = 0x400
SPSUMMON = 0x800
DISSUMMON = 0x1000
FLIP = 0x2000
DISCARD = 0x4000
RDAMAGE = 0x8000
RRECOVER = 0x10000
RETURN = 0x20000
FUSION = 0x40000
SYNCHRO = 0x80000
RITUAL = 0x100000
XYZ = 0x200000
REPLACE = 0x1000000
DRAW = 0x2000000
REDIRECT = 0x4000000
REVEAL = 0x8000000
LINK = 0x10000000
LOST_OVERLAY = 0x20000000
MAINTENANCE = 0x40000000
ACTION = 0x80000000
PROCEDURE = SYNCHRO | XYZ | LINK
@unique
class OPCODE(IntFlag):
ADD = 0x40000000
SUB = 0x40000001
MUL = 0x40000002
DIV = 0x40000003
AND = 0x40000004
OR = 0x40000005
NEG = 0x40000006
NOT = 0x40000007
ISCODE = 0x40000100
ISSETCARD = 0x40000101
ISTYPE = 0x40000102
ISRACE = 0x40000103
ISATTRIBUTE = 0x40000104
@unique
class INFORM(Flag):
PLAYER = auto()
OPPONENT = auto()
ALL = PLAYER | OPPONENT
@unique
class DECK(Flag):
OWNED = auto()
OTHER = auto()
PUBLIC = auto()
ALL = OWNED | OTHER # should only be used for admins
VISIBLE = OWNED | PUBLIC # default scope for players
\ No newline at end of file
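The `IntFlag` enums and the `*2str` tables above are meant to be combined: a raw integer from the card database is tested bit by bit against the single-flag members. A minimal sketch of that decoding follows; `CardType`, `NAMES` and `decode_types` are illustrative names, and only a subset of the `TYPE` values is copied so the example stays self-contained.

```python
from enum import IntFlag
from typing import List


class CardType(IntFlag):
    # Subset of the TYPE flags defined above, with the same values.
    MONSTER = 0x1
    SPELL = 0x2
    TRAP = 0x4
    EFFECT = 0x20
    TUNER = 0x1000
    QUICKPLAY = 0x10000


# Display names, in the order they should appear in the decoded list.
NAMES = {
    CardType.MONSTER: "Monster",
    CardType.SPELL: "Spell",
    CardType.TRAP: "Trap",
    CardType.EFFECT: "Effect",
    CardType.TUNER: "Tuner",
    CardType.QUICKPLAY: "Quick-play",
}


def decode_types(value: int) -> List[str]:
    """List the display names of all known single-bit flags set in `value`."""
    return [name for flag, name in NAMES.items() if value & flag]


print(decode_types(0x1021))   # monster, effect and tuner bits are set
print(decode_types(0x10002))  # spell and quick-play bits are set
```

Because each member occupies its own bit, a card can carry several types at once and the order of the result is controlled entirely by the lookup table.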
from typing import List, Union
from dataclasses import dataclass

import sqlite3

import pandas as pd

from ygoai.constants import TYPE, type2str, attribute2str, race2str


def parse_types(value):
    types = []
    all_types = list(type2str.keys())
    for t in all_types:
        if value & t:
            types.append(type2str[t])
    return types


def parse_attribute(value):
    attribute = attribute2str.get(value, None)
    assert attribute, "Invalid attribute, value: " + str(value)
    return attribute


def parse_race(value):
    race = race2str.get(value, None)
    assert race, "Invalid race, value: " + str(value)
    return race


@dataclass
class Card:
    code: int
    name: str
    desc: str
    types: List[str]

    def format(self):
        return format_card(self)


@dataclass
class MonsterCard(Card):
    atk: int
    def_: int
    level: int
    race: str
    attribute: str


@dataclass
class SpellCard(Card):
    pass


@dataclass
class TrapCard(Card):
    pass


def format_monster_card(card: MonsterCard):
    name = card.name
    typ = "/".join(card.types)

    attribute = card.attribute
    race = card.race

    level = str(card.level)

    atk = str(card.atk)
    if atk == '-2':
        atk = '?'

    def_ = str(card.def_)
    if def_ == '-2':
        def_ = '?'

    if typ == 'Monster/Normal':
        desc = "-"
    else:
        desc = card.desc

    columns = [name, typ, attribute, race, level, atk, def_, desc]
    return " | ".join(columns)


def format_spell_trap_card(card: Union[SpellCard, TrapCard]):
    name = card.name
    typ = "/".join(card.types)
    desc = card.desc

    columns = [name, typ, desc]
    return " | ".join(columns)


def format_card(card: Card):
    if isinstance(card, MonsterCard):
        return format_monster_card(card)
    elif isinstance(card, (SpellCard, TrapCard)):
        return format_spell_trap_card(card)
    else:
        raise ValueError("Invalid card type: " + str(card))


## For analyzing cards.db

def parse_monster_card(data) -> MonsterCard:
    code = int(data['id'])
    name = data['name']
    desc = data['desc']

    types = parse_types(int(data['type']))

    atk = int(data['atk'])
    def_ = int(data['def'])
    level = int(data['level'])

    if level >= 16:
        # pendulum monster
        level = level % 16

    race = parse_race(int(data['race']))
    attribute = parse_attribute(int(data['attribute']))
    return MonsterCard(code, name, desc, types, atk, def_, level, race, attribute)


def parse_spell_card(data) -> SpellCard:
    code = int(data['id'])
    name = data['name']
    desc = data['desc']

    types = parse_types(int(data['type']))
    return SpellCard(code, name, desc, types)


def parse_trap_card(data) -> TrapCard:
    code = int(data['id'])
    name = data['name']
    desc = data['desc']

    types = parse_types(int(data['type']))
    return TrapCard(code, name, desc, types)


def parse_card(data) -> Card:
    type_ = data['type']
    if type_ & TYPE.MONSTER:
        return parse_monster_card(data)
    elif type_ & TYPE.SPELL:
        return parse_spell_card(data)
    elif type_ & TYPE.TRAP:
        return parse_trap_card(data)
    else:
        raise ValueError("Invalid card type: " + str(type_))


def read_cards(cards_path):
    conn = sqlite3.connect(cards_path)
    cursor = conn.cursor()

    cursor.execute("SELECT * FROM datas")
    datas_rows = cursor.fetchall()
    datas_columns = [description[0] for description in cursor.description]
    datas_df = pd.DataFrame(datas_rows, columns=datas_columns)

    cursor.execute("SELECT * FROM texts")
    texts_rows = cursor.fetchall()
    texts_columns = [description[0] for description in cursor.description]
    texts_df = pd.DataFrame(texts_rows, columns=texts_columns)

    cursor.close()
    conn.close()

    texts_df = texts_df.loc[:, ['id', 'name', 'desc']]
    merged_df = pd.merge(texts_df, datas_df, on='id')

    cards_data = merged_df.to_dict('records')
    cards = [parse_card(data) for data in cards_data]
    return merged_df, cards
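`read_cards` joins the `datas` and `texts` tables on `id` before parsing each record. The same join can be expressed directly in SQL; the sketch below does so against an in-memory database so that it runs without a real `cards.cdb`. The reduced table layout and the two sample rows are assumptions made for illustration, and the description column holds a placeholder instead of real card text.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows can then be read by column name

# Reduced, assumed layout: only the columns that the parsers above read.
conn.execute(
    "CREATE TABLE datas (id INTEGER PRIMARY KEY, type INTEGER, atk INTEGER, "
    "def INTEGER, level INTEGER, race INTEGER, attribute INTEGER)"
)
conn.execute('CREATE TABLE texts (id INTEGER PRIMARY KEY, name TEXT, "desc" TEXT)')

# Illustrative rows: a Normal Monster (type 0x11) and a Spell (type 0x2).
conn.executemany(
    "INSERT INTO datas VALUES (?, ?, ?, ?, ?, ?, ?)",
    [(89631139, 0x11, 3000, 2500, 8, 0x2000, 0x10), (53129443, 0x2, 0, 0, 0, 0, 0)],
)
conn.executemany(
    "INSERT INTO texts VALUES (?, ?, ?)",
    [
        (89631139, "Blue-Eyes White Dragon", "(card text omitted)"),
        (53129443, "Dark Hole", "(card text omitted)"),
    ],
)

rows = conn.execute(
    'SELECT t.id AS id, t.name AS name, t."desc" AS "desc", d.type AS type, '
    "d.atk AS atk, d.def AS def, d.level AS level, d.race AS race, "
    "d.attribute AS attribute "
    "FROM texts AS t JOIN datas AS d ON d.id = t.id ORDER BY t.id"
).fetchall()
cards = [dict(row) for row in rows]
conn.close()

for card in cards:
    print(card["id"], card["name"], hex(card["type"]))
```

Each resulting dictionary has the same keys that `parse_card` reads, so this variant could feed the parsers without the intermediate DataFrame; whether that is preferable depends on whether the merged DataFrame is needed elsewhere.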
import re

import numpy as np
import gymnasium as gym


class RecordEpisodeStatistics(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.episode_returns = None
        self.episode_lengths = None

    def reset(self, **kwargs):
        observations, infos = self.env.reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations, infos

    def step(self, action):
        observations, rewards, terminated, truncated, infos = super().step(action)
        dones = np.logical_or(terminated, truncated)
        self.episode_returns += rewards
        self.episode_lengths += 1
        self.returned_episode_returns[:] = self.episode_returns
        self.returned_episode_lengths[:] = self.episode_lengths
        self.episode_returns *= 1 - dones
        self.episode_lengths *= 1 - dones
        infos["r"] = self.returned_episode_returns
        infos["l"] = self.returned_episode_lengths
        return (
            observations,
            rewards,
            dones,
            infos,
        )


class CompatEnv(gym.Wrapper):

    def reset(self, **kwargs):
        observations, infos = super().reset(**kwargs)
        return observations, infos

    def step(self, action):
        observations, rewards, terminated, truncated, infos = self.env.step(action)
        dones = np.logical_or(terminated, truncated)
        return (
            observations,
            rewards,
            dones,
            infos,
        )


def split_param_groups(model, regex):
    embed_params = []
    other_params = []
    for name, param in model.named_parameters():
        if re.search(regex, name):
            embed_params.append(param)
        else:
            other_params.append(param)
    return [
        {'params': embed_params}, {'params': other_params}
    ]


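class Elo:

    def __init__(self, k = 10, r0 = 1500, r1 = 1500):
        self.r0 = r0
        self.r1 = r1
        self.k = k

    def update(self, winner):
        # expect_result(r0, r1) is the expected score of player 1, so the
        # winner always gains k * (1 - expected score of the winner).
        expected_1 = self.expect_result(self.r0, self.r1)
        if winner == 1:
            diff = -self.k * (1 - expected_1)
        else:
            diff = self.k * expected_1
        self.r0 += diff
        self.r1 -= diff

    def expect_result(self, p0, p1):
        # Expected score of the player rated p1 against the player rated p0.
        exp = (p0 - p1) / 400.0
        return 1 / ((10.0 ** (exp)) + 1)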
\ No newline at end of file
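For reference, the standard two-player Elo update can be written as a pair of pure functions. This is a sketch of the textbook formula rather than of the `Elo` class above; `expected_score` and `elo_update` are names chosen for this example.

```python
from typing import Tuple


def expected_score(r_a: float, r_b: float) -> float:
    """Expected score of A against B under the standard Elo logistic model."""
    return 1.0 / (1.0 + 10.0 ** ((r_b - r_a) / 400.0))


def elo_update(r0: float, r1: float, winner: int, k: float = 10.0) -> Tuple[float, float]:
    """Return the new ratings (r0, r1) after one game won by player `winner` (0 or 1)."""
    score_0 = 1.0 if winner == 0 else 0.0
    delta = k * (score_0 - expected_score(r0, r1))
    # Zero-sum update: whatever player 0 gains, player 1 loses.
    return r0 + delta, r1 - delta


print(elo_update(1500.0, 1500.0, winner=0))  # equal ratings: 5 points change hands
print(elo_update(1700.0, 1500.0, winner=0))  # favourite wins: small gain
print(elo_update(1700.0, 1500.0, winner=1))  # upset: large swing
```

The useful property to check in any implementation is that a favourite gains less than `k / 2` for a win while an underdog gains more, and that the sum of the two ratings never changes.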
import itertools
from pathlib import Path

from ygoenv.ygopro import init_module


def load_deck(fn):
    with open(fn) as f:
        lines = f.readlines()
    # Keep the main and extra deck; stop at the side-deck section.
    noside = itertools.takewhile(lambda x: "side" not in x, lines)
    # strip() also handles CRLF line endings and a last line without a newline.
    deck = [int(line) for line in noside if line.strip().isdigit()]
    return deck


def get_root_directory():
    cur = Path(__file__).resolve()
    return str(cur.parent.parent)


def extract_deck_name(path):
    return Path(path).stem


_languages = {
    "english": "en",
    "chinese": "zh",
}


def init_ygopro(lang, deck, code_list_file, preload_tokens=True):
    short = _languages[lang]
    db_path = Path(get_root_directory(), 'assets', 'locale', short, 'cards.cdb')
    deck_fp = Path(deck)
    if deck_fp.is_dir():
        decks = {f.stem: str(f) for f in deck_fp.glob("*.ydk")}
        deck_dir = deck_fp
        deck_name = 'random'
    else:
        deck_name = deck_fp.stem
        decks = {deck_name: deck}
        deck_dir = deck_fp.parent
    if preload_tokens:
        token_deck = deck_dir / "_tokens.ydk"
        if not token_deck.exists():
            raise FileNotFoundError(f"Token deck not found: {token_deck}")
        decks["_tokens"] = str(token_deck)
    init_module(str(db_path), code_list_file, decks)
    return deck_name
\ No newline at end of file
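`load_deck` reads a `.ydk` file, keeps every numeric line up to the side-deck marker and ignores the section headers. The sketch below shows the same filtering on an in-memory sample; the sample text and the name `parse_main_and_extra` are made up for illustration, and the section markers follow the layout such deck files commonly use.

```python
import itertools
from typing import Iterable, List

# Illustrative deck text: section headers, card codes, then the side deck.
SAMPLE_YDK = """#created by ...
#main
89631139
89631139
53129443
#extra
!side
5318639
"""


def parse_main_and_extra(lines: Iterable[str]) -> List[int]:
    """Collect card codes from the main and extra deck, skipping the side deck."""
    before_side = itertools.takewhile(lambda line: "side" not in line, lines)
    # Header lines such as "#main" are not purely numeric and are dropped here.
    return [int(line) for line in before_side if line.strip().isdigit()]


deck = parse_main_and_extra(SAMPLE_YDK.splitlines())
print(deck)
```

`takewhile` stops at the first line containing `side`, so nothing after the marker is read even if it is numeric.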
include ygoenv/*/*.so
\ No newline at end of file
import envpool2
print(envpool2.list_all_envs())
\ No newline at end of file
from setuptools import setup, find_packages

__version__ = "0.0.1"

INSTALL_REQUIRES = [
    "setuptools",
    "wheel",
    "numpy",
    "dm-env",
    "gym>=0.26",
    "gymnasium>=0.26,!=0.27.0",
    "optree>=0.6.0",
    "packaging",
]

setup(
    name="ygoenv",
    version=__version__,
    # `include` takes a sequence of glob patterns, so pass a list rather than a bare string.
    packages=find_packages(include=['ygoenv*']),
    long_description="",
    install_requires=INSTALL_REQUIRES,
    python_requires=">=3.7",
    include_package_data=True,
)
\ No newline at end of file
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvPool package for efficient RL environment simulation."""
import ygoenv.entry # noqa: F401
from ygoenv.registration import (
list_all_envs,
make,
make_dm,
make_gym,
make_gymnasium,
make_spec,
register,
)
__version__ = "0.8.4"
__all__ = [
"register",
"make",
"make_dm",
"make_gym",
"make_gymnasium",
"make_spec",
"list_all_envs",
]
#ifndef THREAD_POOL_H
#define THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
}
#endif
\ No newline at end of file
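The pool in `ThreadPool.h` combines three pieces: a shared task queue, a fixed set of worker threads that block on it, and a future handed back to the caller for each task. A rough Python analogue of that design is sketched below. `MiniThreadPool` is a hypothetical name, a sentinel object replaces the `stop` flag, and unlike the C++ version this sketch does not reject tasks submitted after shutdown.

```python
import queue
import threading
from concurrent.futures import Future
from typing import Any, Callable, List


class MiniThreadPool:
    """Fixed-size worker pool: tasks go into a queue, results come back via futures."""

    _STOP = object()  # sentinel telling a worker to exit

    def __init__(self, num_threads: int) -> None:
        self._tasks: "queue.Queue[Any]" = queue.Queue()
        self._workers: List[threading.Thread] = []
        for _ in range(num_threads):
            worker = threading.Thread(target=self._run, daemon=True)
            worker.start()
            self._workers.append(worker)

    def _run(self) -> None:
        while True:
            item = self._tasks.get()  # blocks until a task or the sentinel arrives
            if item is self._STOP:
                return
            future, fn, args = item
            try:
                future.set_result(fn(*args))
            except Exception as exc:
                future.set_exception(exc)  # forward the failure to the caller

    def enqueue(self, fn: Callable[..., Any], *args: Any) -> Future:
        future: Future = Future()
        self._tasks.put((future, fn, args))
        return future

    def shutdown(self) -> None:
        # Sentinels are queued behind all pending tasks, so those still run first.
        for _ in self._workers:
            self._tasks.put(self._STOP)
        for worker in self._workers:
            worker.join()


pool = MiniThreadPool(4)
futures = [pool.enqueue(pow, n, 2) for n in range(5)]
squares = [f.result() for f in futures]
pool.shutdown()
print(squares)
```

Collecting the futures in submission order, as in the usage lines, gives results in that order regardless of which worker finished first.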
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ACTION_BUFFER_QUEUE_H_
#define YGOENV_CORE_ACTION_BUFFER_QUEUE_H_
#ifndef MOODYCAMEL_DELETE_FUNCTION
#define MOODYCAMEL_DELETE_FUNCTION = delete
#endif
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>
#include "ygoenv/core/array.h"
#include "concurrentqueue/moodycamel/lightweightsemaphore.h"
/**
* Lock-free action buffer queue.
*/
class ActionBufferQueue {
public:
struct ActionSlice {
int env_id;
int order;
bool force_reset;
};
protected:
std::atomic<uint64_t> alloc_ptr_, done_ptr_;
std::size_t queue_size_;
std::vector<ActionSlice> queue_;
moodycamel::LightweightSemaphore sem_, sem_enqueue_, sem_dequeue_;
public:
explicit ActionBufferQueue(std::size_t num_envs)
: alloc_ptr_(0),
done_ptr_(0),
queue_size_(num_envs * 2),
queue_(queue_size_),
sem_(0),
sem_enqueue_(1),
sem_dequeue_(1) {}
void EnqueueBulk(const std::vector<ActionSlice>& action) {
// ensure only one enqueue_bulk happens at any time
while (!sem_enqueue_.wait()) {
}
uint64_t pos = alloc_ptr_.fetch_add(action.size());
for (std::size_t i = 0; i < action.size(); ++i) {
queue_[(pos + i) % queue_size_] = action[i];
}
sem_.signal(action.size());
sem_enqueue_.signal(1);
}
ActionSlice Dequeue() {
while (!sem_.wait()) {
}
while (!sem_dequeue_.wait()) {
}
auto ptr = done_ptr_.fetch_add(1);
auto ret = queue_[ptr % queue_size_];
sem_dequeue_.signal(1);
return ret;
}
std::size_t SizeApprox() {
return static_cast<std::size_t>(alloc_ptr_ - done_ptr_);
}
};
#endif // YGOENV_CORE_ACTION_BUFFER_QUEUE_H_
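The queue above is a fixed-size ring buffer: a counting semaphore tracks how many items are ready, and two monotonically increasing counters are reduced modulo the buffer size to find the slot to write or read. The sketch below mirrors that structure in Python with `threading` primitives. `RingActionQueue` is a hypothetical name, plain locks stand in for the two binary semaphores, and as in the original there is no protection against overwriting unread slots, so callers must keep fewer than `2 * num_envs` items outstanding.

```python
import threading
from typing import Any, List, Sequence


class RingActionQueue:
    """Bounded FIFO backed by a ring buffer and a counting semaphore."""

    def __init__(self, num_envs: int) -> None:
        self._size = num_envs * 2
        self._slots: List[Any] = [None] * self._size
        self._alloc = 0  # total number of items ever enqueued
        self._done = 0   # total number of items ever dequeued
        self._items = threading.Semaphore(0)  # counts items ready to be dequeued
        self._enqueue_lock = threading.Lock()
        self._dequeue_lock = threading.Lock()

    def enqueue_bulk(self, actions: Sequence[Any]) -> None:
        with self._enqueue_lock:  # only one bulk enqueue at a time
            pos = self._alloc
            self._alloc += len(actions)
            for i, action in enumerate(actions):
                self._slots[(pos + i) % self._size] = action
        for _ in actions:
            self._items.release()  # wake one consumer per enqueued item

    def dequeue(self) -> Any:
        self._items.acquire()  # blocks until at least one item is ready
        with self._dequeue_lock:
            item = self._slots[self._done % self._size]
            self._done += 1
        return item

    def size_approx(self) -> int:
        return self._alloc - self._done


q = RingActionQueue(num_envs=2)  # four slots
q.enqueue_bulk(["a", "b", "c"])
first = [q.dequeue() for _ in range(3)]
q.enqueue_bulk(["d", "e", "f"])  # wraps around the end of the buffer
second = [q.dequeue() for _ in range(3)]
print(first, second, q.size_approx())
```

Sizing the buffer at twice the number of environments leaves room for a full batch of actions plus the stop signals that the pool enqueues on shutdown.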
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ARRAY_H_
#define YGOENV_CORE_ARRAY_H_
#include <glog/logging.h>
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "ygoenv/core/spec.h"
class Array {
public:
std::size_t size;
std::size_t ndim;
std::size_t element_size;
protected:
std::vector<std::size_t> shape_;
std::shared_ptr<char> ptr_;
template <class Shape, class Deleter>
Array(char* ptr, Shape&& shape, std::size_t element_size, // NOLINT
Deleter&& deleter)
: size(Prod(shape.data(), shape.size())),
ndim(shape.size()),
element_size(element_size),
shape_(std::forward<Shape>(shape)),
ptr_(ptr, std::forward<Deleter>(deleter)) {}
template <class Shape>
Array(std::shared_ptr<char> ptr, Shape&& shape, std::size_t element_size)
: size(Prod(shape.data(), shape.size())),
ndim(shape.size()),
element_size(element_size),
shape_(std::forward<Shape>(shape)),
ptr_(std::move(ptr)) {}
public:
Array() = default;
/**
* Construct an `Array` of shape defined by `spec`, with `data` as pointer
* to its raw memory. When no deleter is given, an empty deleter is used,
* which means the Array does not own the memory.
*/
template <class Deleter>
Array(const ShapeSpec& spec, char* data, Deleter&& deleter) // NOLINT
: Array(data, spec.Shape(), spec.element_size,
std::forward<Deleter>(deleter)) {}
Array(const ShapeSpec& spec, char* data)
: Array(data, spec.Shape(), spec.element_size, [](char* /*unused*/) {}) {}
/**
* Construct an `Array` of shape defined by `spec`. This constructor
* allocates and owns the memory.
*/
explicit Array(const ShapeSpec& spec)
: Array(spec, nullptr, [](char* /*unused*/) {}) {
ptr_.reset(new char[size * element_size](),
[](const char* p) { delete[] p; });
}
/**
* Take multidimensional index into the Array.
*/
template <typename... Index>
inline Array operator()(Index... index) const {
constexpr std::size_t num_index = sizeof...(Index);
DCHECK_GE(ndim, num_index);
std::size_t offset = 0;
std::size_t i = 0;
for (((offset = offset * shape_[i++] + index), ...); i < ndim; ++i) {
offset *= shape_[i];
}
return Array(
ptr_.get() + offset * element_size,
std::vector<std::size_t>(shape_.begin() + num_index, shape_.end()),
element_size, [](char* /*unused*/) {});
}
/**
* Index operator of array, takes the index along the first axis.
*/
inline Array operator[](int index) const { return this->operator()(index); }
/**
* Take a slice at the first axis of the Array.
*/
[[nodiscard]] Array Slice(std::size_t start, std::size_t end) const {
DCHECK_GT(ndim, (std::size_t)0);
CHECK_GE(shape_[0], end);
CHECK_GE(end, start);
std::vector<std::size_t> new_shape(shape_);
new_shape[0] = end - start;
std::size_t offset = 0;
if (shape_[0] > 0) {
offset = start * size / shape_[0];
}
return {ptr_.get() + offset * element_size, std::move(new_shape),
element_size, [](char* /*unused*/) {}};
}
/**
* Copy the content of another Array to this Array.
*/
void Assign(const Array& value) const {
DCHECK_EQ(element_size, value.element_size)
<< " element size doesn't match";
DCHECK_EQ(size, value.size) << " size doesn't match";
std::memcpy(ptr_.get(), value.ptr_.get(), size * element_size);
}
/**
* Assign to this Array a scalar value. This Array needs to have a scalar
* shape.
*/
template <typename T,
std::enable_if_t<!std::is_same_v<T, Array>, bool> = true>
void operator=(const T& value) const {
DCHECK_EQ(element_size, sizeof(T)) << " element size doesn't match";
DCHECK_EQ(size, (std::size_t)1) << " assigning scalar to non-scalar array";
*reinterpret_cast<T*>(ptr_.get()) = value;
}
/**
* Fills this array with a scalar value of type T.
*/
template <typename T>
void Fill(const T& value) const {
DCHECK_EQ(element_size, sizeof(T)) << " element size doesn't match";
auto* data = reinterpret_cast<T*>(ptr_.get());
std::fill(data, data + size, value);
}
/**
* Copy `sz` elements of type `T`, starting at `buff`, into the memory of
* this Array.
*/
template <typename T>
void Assign(const T* buff, std::size_t sz) const {
DCHECK_EQ(sz, size) << " assignment size mismatch";
DCHECK_EQ(sizeof(T), element_size) << " element size mismatch";
std::memcpy(ptr_.get(), buff, sz * sizeof(T));
}
/**
* Size of axis `dim`.
*/
[[nodiscard]] inline std::size_t Shape(std::size_t dim) const {
return shape_[dim];
}
/**
* Shape
*/
[[nodiscard]] inline const std::vector<std::size_t>& Shape() const {
return shape_;
}
/**
* Pointer to the raw memory.
*/
[[nodiscard]] inline void* Data() const { return ptr_.get(); }
/**
* Truncate the Array. Return a new Array that shares the same memory
* location but with a truncated shape.
*/
[[nodiscard]] Array Truncate(std::size_t end) const {
auto new_shape = std::vector<std::size_t>(shape_);
new_shape[0] = end;
Array ret(ptr_, std::move(new_shape), element_size);
return ret;
}
void Zero() const { std::memset(ptr_.get(), 0, size * element_size); }
[[nodiscard]] std::shared_ptr<char> SharedPtr() const { return ptr_; }
};
template <typename Dtype>
class TArray : public Array {
template <class Shape, class Deleter>
TArray(char* ptr, Shape&& shape, std::size_t element_size, // NOLINT
Deleter&& deleter)
: Array(ptr, shape, element_size, deleter) {}
template <class Shape>
TArray(const std::shared_ptr<char>& ptr, Shape&& shape,
std::size_t element_size)
: Array(ptr, shape, element_size) {}
public:
TArray() = default;
explicit TArray(const Spec<Dtype>& spec) : Array(spec) {}
explicit TArray(const Spec<Dtype>& spec, const char* data)
: Array(spec, data) {}
template <typename A, std::enable_if_t<std::is_same_v<std::decay_t<A>, Array>,
bool> = true>
explicit TArray(A&& array) : Array(std::forward<A>(array)) { // NOLINT
DCHECK_EQ(array.element_size, sizeof(Dtype));
}
/**
* Take multidimensional index into the Array.
*/
template <typename... Index>
inline TArray operator()(Index... index) const {
return TArray(Array::operator()(index...));
}
/**
* Index operator of array, takes the index along the first axis.
*/
inline TArray operator[](int index) const { return this->operator()(index); }
/**
* Take a slice at the first axis of the Array.
*/
[[nodiscard]] TArray Slice(std::size_t start, std::size_t end) const {
return TArray(Array::Slice(start, end));
}
/**
* Copy the content of another Array to this Array.
*/
void Assign(const TArray& value) const { Array::Assign(value); }
void Assign(const Array& value) const { Array::Assign(value); }
/**
* Assign a scalar value.
*/
template <typename T,
std::enable_if_t<!std::is_same_v<T, TArray>, bool> = true>
void operator=(const T& value) const {
*reinterpret_cast<Dtype*>(ptr_.get()) = static_cast<Dtype>(value);
}
/**
* Fills this array with a scalar value of type T.
*/
template <typename T>
void Fill(const T& value) const {
auto data = reinterpret_cast<Dtype*>(ptr_.get());
std::fill(data, data + size, static_cast<Dtype>(value));
}
/**
* Copy `sz` elements of type `Dtype`, starting at `buff`, into the memory of
* this Array.
*/
void Assign(const Dtype* buff, std::size_t sz) const {
std::memcpy(ptr_.get(), buff, sz * sizeof(Dtype));
}
operator Dtype&() const { // NOLINT
return *reinterpret_cast<Dtype*>(ptr_.get());
}
/**
* Cast the Array to a scalar value of type `T`. This Array needs to have a
* scalar shape.
*/
template <typename T,
std::enable_if_t<!std::is_same_v<T, Dtype>, bool> = true>
operator T() const { // NOLINT
DCHECK_EQ(size, (std::size_t)1)
<< " Array with a non-scalar shape can't be used as a scalar";
return static_cast<T>(*reinterpret_cast<Dtype*>(ptr_.get()));
}
/**
* Truncate the Array. Return a new Array that shares the same memory
* location but with a truncated shape.
*/
[[nodiscard]] TArray Truncate(std::size_t end) const {
auto new_shape = std::vector<std::size_t>(shape_);
new_shape[0] = end;
TArray ret(ptr_, std::move(new_shape), element_size);
return ret;
}
};
#endif // YGOENV_CORE_ARRAY_H_
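`Array::operator()` turns a partial multi-dimensional index into a flat element offset in row-major order and returns a view over the remaining axes. The arithmetic can be sketched on its own as follows. `index_view` is a hypothetical helper, and the bounds checks are an addition that the C++ code does not perform.

```python
from typing import Sequence, Tuple


def index_view(shape: Sequence[int], index: Sequence[int]) -> Tuple[int, Tuple[int, ...]]:
    """Return (flat element offset, remaining shape) for an index over the leading axes."""
    if len(index) > len(shape):
        raise IndexError("too many indices")
    offset = 0
    # Row-major accumulation over the indexed axes: offset = offset * dim + idx.
    for dim, idx in zip(shape, index):
        if not 0 <= idx < dim:
            raise IndexError("index out of range")
        offset = offset * dim + idx
    # Each axis that is not indexed scales the offset by its full extent.
    for dim in shape[len(index):]:
        offset *= dim
    return offset, tuple(shape[len(index):])


print(index_view((2, 3, 4), (1, 2)))  # → (20, (4,))
print(index_view((2, 3, 4), (1,)))    # → (12, (3, 4))
```

Multiplying the element offset by `element_size` gives the byte offset that the C++ code adds to the base pointer.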
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ASYNC_ENVPOOL_H_
#define YGOENV_CORE_ASYNC_ENVPOOL_H_
#include <algorithm>
#include <atomic>
#include <chrono>
#include <future>
#include <memory>
#include <thread>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "ygoenv/core/action_buffer_queue.h"
#include "ygoenv/core/array.h"
#include "ygoenv/core/envpool.h"
#include "ygoenv/core/state_buffer_queue.h"
/**
* Async EnvPool
*
* batch-action -> action buffer queue -> threadpool -> state buffer queue
*
* The stepping worker threads are tailored to EnvPool, so they do not go
* through the generic third_party ThreadPool (which is really slow); that
* pool is only used to construct the environments in parallel.
*/
template <typename Env>
class AsyncEnvPool : public EnvPool<typename Env::Spec> {
protected:
std::size_t num_envs_;
std::size_t batch_;
std::size_t max_num_players_;
std::size_t num_threads_;
bool is_sync_;
std::atomic<int> stop_;
std::atomic<std::size_t> stepping_env_num_;
std::vector<std::thread> workers_;
std::unique_ptr<ActionBufferQueue> action_buffer_queue_;
std::unique_ptr<StateBufferQueue> state_buffer_queue_;
std::vector<std::unique_ptr<Env>> envs_;
std::vector<std::atomic<int>> stepping_env_;
std::chrono::duration<double> dur_send_, dur_recv_, dur_send_all_;
template <typename V>
void SendImpl(V&& action) {
int* env_id = static_cast<int*>(action[0].Data());
int shared_offset = action[0].Shape(0);
std::vector<ActionSlice> actions;
std::shared_ptr<std::vector<Array>> action_batch =
std::make_shared<std::vector<Array>>(std::forward<V>(action));
for (int i = 0; i < shared_offset; ++i) {
int eid = env_id[i];
envs_[eid]->SetAction(action_batch, i);
actions.emplace_back(ActionSlice{
.env_id = eid,
.order = is_sync_ ? i : -1,
.force_reset = false,
});
}
if (is_sync_) {
stepping_env_num_ += shared_offset;
}
// add to abq
auto start = std::chrono::system_clock::now();
action_buffer_queue_->EnqueueBulk(actions);
dur_send_ += std::chrono::system_clock::now() - start;
}
public:
using Spec = typename Env::Spec;
using Action = typename Env::Action;
using State = typename Env::State;
using ActionSlice = typename ActionBufferQueue::ActionSlice;
explicit AsyncEnvPool(const Spec& spec)
: EnvPool<Spec>(spec),
num_envs_(spec.config["num_envs"_]),
batch_(spec.config["batch_size"_] <= 0 ? num_envs_
: spec.config["batch_size"_]),
max_num_players_(spec.config["max_num_players"_]),
num_threads_(spec.config["num_threads"_]),
is_sync_(batch_ == num_envs_ && max_num_players_ == 1),
stop_(0),
stepping_env_num_(0),
action_buffer_queue_(new ActionBufferQueue(num_envs_)),
state_buffer_queue_(new StateBufferQueue(
batch_, num_envs_, max_num_players_,
spec.state_spec.template AllValues<ShapeSpec>())),
envs_(num_envs_) {
std::size_t processor_count = std::thread::hardware_concurrency();
ThreadPool init_pool(std::min(processor_count, num_envs_));
std::vector<std::future<void>> result;
for (std::size_t i = 0; i < num_envs_; ++i) {
result.emplace_back(init_pool.enqueue(
[i, spec, this] { envs_[i].reset(new Env(spec, i)); }));
}
for (auto& f : result) {
f.get();
}
if (num_threads_ == 0) {
num_threads_ = std::min(batch_, processor_count);
}
for (std::size_t i = 0; i < num_threads_; ++i) {
workers_.emplace_back([this] {
for (;;) {
ActionSlice raw_action = action_buffer_queue_->Dequeue();
if (stop_ == 1) {
break;
}
int env_id = raw_action.env_id;
int order = raw_action.order;
bool reset = raw_action.force_reset || envs_[env_id]->IsDone();
envs_[env_id]->EnvStep(state_buffer_queue_.get(), order, reset);
}
});
}
if (spec.config["thread_affinity_offset"_] >= 0) {
std::size_t thread_affinity_offset =
spec.config["thread_affinity_offset"_];
for (std::size_t tid = 0; tid < num_threads_; ++tid) {
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
std::size_t cid = (thread_affinity_offset + tid) % processor_count;
CPU_SET(cid, &cpuset);
pthread_setaffinity_np(workers_[tid].native_handle(), sizeof(cpu_set_t),
&cpuset);
}
}
}
~AsyncEnvPool() override {
stop_ = 1;
// LOG(INFO) << "envpool send: " << dur_send_.count();
// LOG(INFO) << "envpool recv: " << dur_recv_.count();
// Enqueue one dummy action per worker so that every thread blocked in
// Dequeue() wakes up, observes stop_ == 1, and exits its loop.
std::vector<ActionSlice> empty_actions(workers_.size());
action_buffer_queue_->EnqueueBulk(empty_actions);
for (auto& worker : workers_) {
worker.join();
}
}
void Send(const Action& action) {
SendImpl(action.template AllValues<Array>());
}
void Send(const std::vector<Array>& action) override { SendImpl(action); }
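// `action` is an lvalue inside this function; std::move restores the rvalue
// so that SendImpl can move the vector instead of copying it.
void Send(std::vector<Array>&& action) override {
SendImpl(std::move(action));
}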
std::vector<Array> Recv() override {
int additional_wait = 0;
if (is_sync_ && stepping_env_num_ < batch_) {
additional_wait = batch_ - stepping_env_num_;
}
// Monotonic clock: the measured wait is unaffected by wall-clock adjustments.
auto start = std::chrono::steady_clock::now();
auto ret = state_buffer_queue_->Wait(additional_wait);
dur_recv_ += std::chrono::steady_clock::now() - start;
if (is_sync_) {
stepping_env_num_ -= ret[0].Shape(0);
}
return ret;
}
void Reset(const Array& env_ids) override {
TArray<int> tenv_ids(env_ids);
int shared_offset = tenv_ids.Shape(0);
std::vector<ActionSlice> actions(shared_offset);
for (int i = 0; i < shared_offset; ++i) {
actions[i].force_reset = true;
actions[i].env_id = tenv_ids[i];
actions[i].order = is_sync_ ? i : -1;
}
if (is_sync_) {
stepping_env_num_ += shared_offset;
}
action_buffer_queue_->EnqueueBulk(actions);
}
};
#endif // YGOENV_CORE_ASYNC_ENVPOOL_H_
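The destructor of `AsyncEnvPool` stops its workers by setting `stop_` and then enqueueing one dummy `ActionSlice` per worker, so that each thread blocked in `Dequeue()` wakes up and exits. The following is a minimal sketch of that shutdown pattern only. `BlockingQueue` and `WorkerPool` are hypothetical stand-ins written for illustration; they are not part of this commit, and the real `ActionBufferQueue` may be implemented quite differently.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical stand-in for ActionBufferQueue: a mutex-guarded blocking queue.
class BlockingQueue {
 public:
  void EnqueueBulk(const std::vector<int>& items) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      for (int item : items) {
        queue_.push(item);
      }
    }
    cv_.notify_all();
  }

  // Blocks until an item is available.
  int Dequeue() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    int item = queue_.front();
    queue_.pop();
    return item;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<int> queue_;
};

// Hypothetical pool that mirrors only the worker loop and the shutdown logic.
class WorkerPool {
 public:
  explicit WorkerPool(std::size_t num_threads) : stop_(0), processed_(0) {
    assert(num_threads > 0);
    for (std::size_t i = 0; i < num_threads; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          int item = queue_.Dequeue();
          if (stop_ == 1) {
            break;  // the stop flag is checked only after a wake-up
          }
          (void)item;  // a real worker would step an environment here
          ++processed_;
        }
      });
    }
  }

  ~WorkerPool() { Shutdown(); }

  void Submit(const std::vector<int>& items) { queue_.EnqueueBulk(items); }

  void Shutdown() {
    if (workers_.empty()) {
      return;  // already shut down
    }
    stop_ = 1;
    // One dummy item per worker: each worker consumes exactly one and exits.
    queue_.EnqueueBulk(std::vector<int>(workers_.size(), 0));
    for (auto& worker : workers_) {
      worker.join();
    }
    workers_.clear();
  }

  int processed() const { return processed_.load(); }

 private:
  std::atomic<int> stop_;
  std::atomic<int> processed_;
  BlockingQueue queue_;
  std::vector<std::thread> workers_;
};
```

As in the original loop, a worker that dequeues a real item after the stop flag is set drops it, so a caller that needs every submitted item processed should wait for completion before shutting down.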
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ENVPOOL_H_
#define YGOENV_CORE_ENVPOOL_H_
#include <stdexcept>
#include <utility>
#include <vector>
#include "ygoenv/core/env_spec.h"
/**
* Templated base class for environment pools. The concrete pool (for example
* AsyncEnvPool) overrides Send, Recv, and Reset; the defaults here only throw.
*/
template <typename EnvSpec>
class EnvPool {
public:
EnvSpec spec;
using Spec = EnvSpec;
using State = NamedVector<typename EnvSpec::StateKeys, std::vector<Array>>;
using Action = NamedVector<typename EnvSpec::ActionKeys, std::vector<Array>>;
explicit EnvPool(EnvSpec spec) : spec(std::move(spec)) {}
virtual ~EnvPool() = default;
protected:
virtual void Send(const std::vector<Array>& action) {
throw std::runtime_error("send not implemented");
}
virtual void Send(std::vector<Array>&& action) {
throw std::runtime_error("send not implemented");
}
virtual std::vector<Array> Recv() {
throw std::runtime_error("recv not implemented");
}
virtual void Reset(const Array& env_ids) {
throw std::runtime_error("reset not implemented");
}
};
#endif // YGOENV_CORE_ENVPOOL_H_
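`EnvPool` gives every virtual method a default body that throws `std::runtime_error`, so a subclass only needs to override the operations it supports. A minimal sketch of that pattern follows. `PoolBase`, `EchoPool`, and `RecvError` are hypothetical names introduced here, and `std::vector<int>` stands in for `std::vector<Array>`. Unlike the original, where the defaults are `protected`, the methods are kept public to keep the sketch short.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified base: every operation throws unless overridden.
class PoolBase {
 public:
  virtual ~PoolBase() = default;
  virtual void Send(const std::vector<int>& /*action*/) {
    throw std::runtime_error("send not implemented");
  }
  virtual std::vector<int> Recv() {
    throw std::runtime_error("recv not implemented");
  }
};

// Hypothetical subclass that implements Send but deliberately not Recv.
class EchoPool : public PoolBase {
 public:
  void Send(const std::vector<int>& action) override { last_ = action; }
  const std::vector<int>& last() const { return last_; }

 private:
  std::vector<int> last_;
};

// Returns the message thrown by Recv(), or an empty string if it succeeds.
std::string RecvError(PoolBase& pool) {
  try {
    pool.Recv();
  } catch (const std::runtime_error& e) {
    return e.what();
  }
  return "";
}
```

The trade-off of this design is that a missing override is reported at run time, on the first call, rather than at compile time as it would be with pure virtual methods.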
from ygoenv.registration import register
register(
task_id="YGOPro-v0",
import_path="ygoenv.ygopro",
spec_cls="YGOProEnvSpec",
dm_cls="YGOProDMEnvPool",
gym_cls="YGOProGymEnvPool",
gymnasium_cls="YGOProGymnasiumEnvPool",
)