Commit 932b6e39 authored by sbl1996@126.com

Merge yugioh-ai and envpool2

# Xmake cache
.xmake/
# MacOS Cache
.DS_Store
*.out
code_list.txt
*.npy
.vscode/
checkpoints/
runs/
logs/
k8s_job/
script
assets/locale/*/*.cdb
assets/locale/*/strings.conf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
*.o
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
yugioh-ai
---------
MIT License
Copyright (c) 2024 Hastur
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Modified from tspivey's yugioh-game <https://github.com/tspivey/yugioh-game>
------------------------------------------------------------------------------
MIT License
Copyright (c) 2017 Tyler Spivey
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Yugioh AI
Yugioh AI uses large language models (LLMs) and reinforcement learning (RL) to play Yu-Gi-Oh. It is inspired by [yugioh-ai](https://github.com/melvinzhang/yugioh-ai) and [yugioh-game](https://github.com/tspivey/yugioh-game), and uses [ygopro-core](https://github.com/Fluorohydride/ygopro-core).
## Usage
### Setup
An automated setup script is provided for Linux (tested on Ubuntu 22.04). It installs all dependencies and builds the library. To run it, execute the following command:
```bash
make setup
```
### Running
Test that the repo is set up correctly by running:
```bash
python cli.py --deck1 deck/Starter.ydk --deck2 deck/BlueEyes.ydk
```
You should see text output showing two random AIs playing a duel by making random moves.
You can set `--seed` to a fixed value to get deterministic results.
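For example, to reproduce a duel exactly across runs:
```bash
python cli.py --deck1 deck/Starter.ydk --deck2 deck/BlueEyes.ydk --seed 42
```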
## Implementation
The implementation is initially based on [yugioh-game](https://github.com/tspivey/yugioh-game). To provide a clean game environment, it removes all server code and only keeps basic duel-related classes and message handlers.
To implement the AI, inspired by [yugioh-ai](https://github.com/melvinzhang/yugioh-ai), every message handler also provides all possible actions that can be taken in response to the message.
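As a minimal illustration of that pattern (the names below are hypothetical, not the repo's actual handler API), each handler decodes an engine message and returns indexed options that an agent can pick from instead of free-form input:
```python
# Hypothetical sketch of the "handler also returns its legal options" pattern;
# this is not the actual message-handler API used in this repo.
def handle_select_position(positions):
    """Given candidate positions decoded from a select-position message,
    build a human-readable prompt and the indexed options an agent can choose."""
    options = [str(i + 1) for i in range(len(positions))]
    prompt = ", ".join(f"{o}: {p}" for o, p in zip(options, positions))
    return prompt, options

prompt, options = handle_select_position(["face-up attack", "face-down defense"])
print(prompt)   # 1: face-up attack, 2: face-down defense
print(options)  # ['1', '2']
```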
## Notes
Never
#created by ...
#main
89631139
89631139
89631139
38517737
38517737
38517737
45467446
23434538
23434538
23434538
71039903
71039903
71039903
45644898
45644898
79814787
8240199
8240199
8240199
97268402
97268402
2295440
6853254
6853254
6853254
38120068
38120068
38120068
39701395
39701395
41620959
41620959
41620959
48800175
48800175
48800175
43898403
43898403
63356631
63356631
#extra
40908371
40908371
59822133
59822133
50954680
83994433
33698022
39030163
31801517
02978414
63767246
64332231
10443957
41999284
!side
8233522
8233522
14558127
14558127
14558127
56399890
53129443
25789292
25789292
43455065
43455065
43898403
11109820
11109820
11109820
#created by ...
#main
89631139
55410871
89631139
80701178
31036355
38517737
80701178
80701178
95492061
95492061
95492061
53303460
53303460
53303460
14558127
14558127
23434538
55410871
55410871
31036355
31036355
48800175
48800175
48800175
70368879
70368879
70368879
21082832
46052429
46052429
46052429
24224830
24224830
24224830
73915051
10045474
10045474
37576645
37576645
37576645
#extra
31833038
85289965
74997493
5043010
65330383
38342335
2857636
28776350
75452921
3987233
3987233
99111753
98978921
41999284
41999284
!side
75732622
15397015
15397015
73642296
23434538
5821478
77058170
3679218
25774450
43898403
23002292
23002292
84749824
84749824
#created by wxapp_ygo
#main
25451383
25451383
32731036
35984222
60242223
60242223
48048590
62962630
62962630
62962630
68468459
68468459
68468459
45484331
19096726
95515789
45883110
14558127
14558127
14558127
23434538
23434538
23434538
36577931
25311006
35269904
44362883
06498706
06498706
06498706
01984618
01984618
75500286
81439173
82738008
29948294
36637374
24224830
24224830
18973184
01041278
10045474
10045474
10045474
17751597
17751597
#extra
11321089
03410461
72272462
87746184
87746184
41373230
51409648
01906812
24915933
70534340
44146295
44146295
92892239
38811586
53971455
!side
27204311
82385847
34267821
73642296
27572350
93039339
18144506
08267140
08267140
08267140
15693423
15693423
15693423
51452091
23002292
#created by wxapp_ygo
#main
63941210
63941210
10604644
46659709
46659709
70095154
70095154
70095154
05370235
23434538
23434538
23434538
23893227
23893227
23893227
01142880
01142880
56364287
56364287
56364287
14532163
14532163
86686671
60600126
60600126
60600126
18144506
12580477
63995093
63995093
03659803
37630732
39973386
84797028
84797028
84797028
64753988
10045474
10045474
10045474
#extra
01546123
87116928
74157028
79229522
79229522
84058253
84058253
22850702
90448279
10443957
73964868
58069384
70369116
46724542
60303245
!side
#created by ...
#main
27204311
10000080
10000080
10000080
95440946
95440946
95440946
14558127
14558127
14558127
23434538
23434538
23434538
68829754
68829754
94224458
31434645
31434645
31434645
10045474
10045474
20612097
20612097
20612097
58921041
58921041
53334471
53334471
90846359
90846359
23516703
23516703
82732705
67007102
93191801
93191801
20590515
20590515
20590515
56984514
#extra
74889525
62541668
90448279
26556950
26096328
49032236
56910167
56910167
56910167
73082255
37129797
03814632
72860663
!side
#created by ...
#main
6631034
6631034
6631034
43096270
43096270
43096270
69247929
69247929
69247929
77542832
77542832
77542832
11091375
11091375
11091375
35052053
35052053
35052053
49881766
83104731
83104731
30190809
30190809
26412047
26412047
26412047
43422537
43422537
43422537
53129443
66788016
66788016
66788016
72302403
72302403
44095762
44095762
44095762
70342110
70342110
#extra
!side
#created by ...
#main
55623480
52467217
52467217
52467217
92826944
92826944
92826944
41562624
41562624
99423156
99423156
94801854
94801854
94801854
49959355
49959355
49959355
79783880
14558127
14558127
14558127
36630403
36630403
23434538
23434538
23434538
97268402
12580477
18144507
75500286
81439173
13965201
13965201
24224830
24224830
40364916
40364916
4333086
4333086
10045474
10045474
40605147
40605147
41420027
#extra
59843383
27548199
50954680
83283063
74586817
52711246
57288064
26326541
98558751
86066372
72860663
86926989
37129797
91420202
41999284
!side
#created by wxapp_ygo
#main
27204311
27204311
26077387
26077387
26077387
37351133
14558127
14558127
14558127
23434538
23434538
23434538
20357457
20357457
94145021
94145021
73594093
35726888
35726888
99550630
35261759
70368879
70368879
25311006
32807846
63166095
63166095
63166095
24224830
24224830
52340444
98338152
98338152
98338152
51227866
09726840
09726840
24010609
24010609
50005218
#extra
86066372
75147529
29301450
98462037
63013339
63288573
63288573
63288573
90673288
90673288
90673288
08491308
08491308
08491308
12421694
!side
19613556
12580477
12580477
04031928
04031928
18144506
35269904
08267140
08267140
31849106
31849106
83326048
83326048
15693423
15693423
#created by wxapp_ygo
#main
62849088
05141117
93490856
93490856
93490856
55273560
55273560
55273560
20001443
20001443
20001443
56495147
56495147
56495147
59438930
59438930
14558127
14558127
23434538
23434538
83764718
93850690
56465981
56465981
56465981
18144506
83308376
24224830
24224830
65681983
39730727
39730727
10045474
10045474
23068051
23068051
78836195
14821890
14821890
14821890
99137266
#extra
53971455
86682165
84815190
96633955
47710198
09464441
69248256
69248256
93039339
05402805
78917791
29301450
!side
#created by ...
#main
27204311
06728559
87052196
87052196
23431858
93490856
93490856
93490856
56495147
56495147
56495147
20001443
20001443
20001443
55273560
55273560
55273560
14558127
14558127
14558127
23434538
23434538
23434538
97268402
97268402
97268402
98159737
35261759
35261759
56465981
56465981
56465981
93850690
24224830
24224830
10045474
10045474
10045474
14821890
14821890
#extra
42632209
60465049
96633955
84815190
47710198
9464441
5041348
69248256
69248256
83755611
43202238
!side
#main
176393
645088
904186
1426715
1799465
2625940
2819436
3285552
4417408
7392746
7610395
8025951
8198621
9047461
9396663
9925983
9929399
10389143
11050416
11050417
11050418
11050419
11654068
11738490
12958920
12965762
13536607
13764603
13935002
14089429
14470846
14470847
14821891
14957441
15341822
15341823
15394084
15590356
15629802
16943771
16946850
17000166
17228909
17418745
18027139
18027140
18027141
18494512
19280590
20001444
20368764
21179144
21770261
21830680
22110648
22404676
22411610
22493812
22953212
23116809
23331401
23837055
24874631
25415053
25419324
26326542
27198002
27204312
27450401
27882994
28053764
28062326
28355719
28674153
29491335
29843092
29843093
29843094
30069399
30327675
30327676
30650148
30765616
30811117
31480216
31533705
31600514
31986289
32056071
32335698
32446631
33676147
34479659
34690954
34767866
34822851
35263181
35268888
35514097
35834120
36629636
38030233
38041941
38053382
39972130
40551411
40633085
40703223
40844553
41329459
41456842
42427231
42671152
42956964
43140792
43664495
44026394
44052075
44092305
44097051
44308318
44330099
44586427
44689689
46173680
46173681
46647145
47658965
48068379
48115278
48411997
49752796
49808197
51208047
51611042
51987572
52340445
52900001
53855410
53855411
54537490
55326323
56051649
56495148
56597273
58371672
59160189
59900656
60025884
60406592
60514626
60764582
60764583
62125439
62481204
62543394
63184228
63442605
64213018
64382840
64583601
65500516
65810490
66200211
66661679
67284108
67489920
67922703
67949764
68815402
69550260
69811711
69868556
69890968
70391589
70465811
70875956
70950699
71645243
72291079
73915052
73915053
73915054
73915055
74440056
74627017
74659583
74983882
75119041
75524093
75622825
75732623
76524507
76589547
77672445
78394033
78789357
78836196
79387393
81767889
82255873
82324106
82340057
82556059
82994510
83239740
84816245
85243785
85771020
85771021
85969518
86801872
86871615
87240372
87669905
88923964
89907228
90884404
91512836
93104633
93130022
93224849
93490857
93912846
94703022
94973029
97452818
98596597
98875864
99092625
99137267
#extra
!side
#created by ...
#main
75498415
81105204
81105204
81105204
58820853
58820853
58820853
49003716
49003716
49003716
14785765
85215458
85215458
85215458
2009101
2009101
2009101
22835145
22835145
22835145
73652465
1475311
1475311
53129443
5318639
5318639
14087893
27243130
27243130
91351370
91351370
91351370
53567095
53567095
53567095
53582587
59839761
72930878
72930878
84749824
#extra
52687916
33236860
16051717
23338098
81983656
69031175
73580471
95040215
76913983
17377751
16195942
86848580
82633039
73347079
78156759
!side
#created by wxapp_ygo
#main
18094166
18094166
18094166
22865492
22865492
27780618
27780618
40044918
40044918
09411399
09411399
89943723
16605586
50720316
59392529
14124483
83965310
23434538
23434538
23434538
14558127
14558127
14558127
94145021
94145021
08949584
08949584
08949584
21143940
21143940
21143940
45906428
52947044
24094653
24094653
24224830
24224830
81439173
32807846
75047173
#extra
30757127
89870349
58481572
58481572
22908820
93347961
46759931
40854197
60461804
56733747
32828466
90590303
58004362
19324993
01948619
!side
94145021
34267821
34267821
35269904
35269904
35269904
14532163
14532163
05758500
18144506
08267140
08267140
08267140
43262273
23002292
#created by ...
#main
14513273
14513273
14513273
76794549
22211622
96227613
96227613
14920218
69610326
69610326
40318957
40318957
40318957
72714461
72714461
72714461
49684352
49684352
49684352
73941492
27204311
14558127
14558127
14558127
23434538
23434538
23434538
01845204
25311006
41620959
41620959
81439173
24224830
24224830
65681983
82190203
82190203
55795155
74850403
10045474
01344018
#extra
43387895
43387895
53262004
76815942
58074177
84815190
30095833
16691074
20665527
04280258
02772337
92812851
45819647
24094258
22125101
!side
27204311
94145021
94145021
94145021
14532163
25311006
25311006
08267140
08267140
08267140
10045474
10045474
15693423
15693423
15693423
#created by ...
#main
3717252
3717252
3717252
77723643
77723643
30328508
30328508
59546797
97518132
34710660
51023024
51023024
4939890
4939890
4939890
59438930
59438930
69764158
24635329
37445295
37445295
23434538
23434538
1475311
11827244
44394295
44394295
44394295
53129443
81439173
6417578
6417578
48130397
23912837
23912837
77505534
77505534
77505534
4904633
40605147
40605147
84749824
#extra
84433295
74822425
74822425
19261966
20366274
48424886
50907446
50907446
94977269
52687916
73580471
31924889
56832966
84013237
82633039
!side
#created by ...
#main
55623480
52467217
52467217
52467217
92826944
92826944
92826944
41562624
41562624
99423156
99423156
94801854
94801854
94801854
49959355
49959355
49959355
79783880
14558127
14558127
14558127
36630403
36630403
23434538
23434538
23434538
97268402
12580477
18144506
75500286
81439173
13965201
13965201
24224830
24224830
40364916
40364916
4333086
4333086
10045474
10045474
40605147
40605147
41420027
#extra
59843383
27548199
50954680
83283063
74586817
52711246
57288064
26326541
98558751
86066372
72860663
86926989
37129797
91420202
41999284
!side
# Dimensions
B: batch size
C: number of channels
H: number of history actions
# Features
f_cards: (B, n_cards, C), features of cards
f_global: (B, C), global features
f_h_actions: (B, H, C), features of history actions
f_actions: (B, max_actions, C), features of current legal actions
output: (B, max_actions, 1), value of each action
# Fusion
## Method 1
```
f_cards -> n encoder layers -> f_cards
f_global -> ResMLP -> f_global
f_cards = f_cards + f_global
f_actions -> n encoder layers -> f_actions
f_cards[id] -> f_a_cards -> ResMLP -> f_a_cards
f_actions = f_a_cards + f_a_feats
f_actions, f_cards -> n decoder layers -> f_actions
f_h_actions -> n encoder layers -> f_h_actions
f_actions, f_h_actions -> n decoder layers -> f_actions
f_actions -> MLP -> output
```
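A rough PyTorch sketch of Method 1 under the shapes above; module sizes and the ResMLP are placeholders, not the exact layers in `ygoai.rl.agent`, and `a_card_index` stands for the per-action card indices used in the `f_cards[id]` step.
```
import torch
import torch.nn as nn

class ResMLP(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(c, c), nn.ReLU(), nn.Linear(c, c))

    def forward(self, x):
        return x + self.net(x)

class FusionMethod1(nn.Module):
    def __init__(self, c=128, n_layers=2, n_heads=2):
        super().__init__()
        enc = lambda: nn.TransformerEncoderLayer(c, n_heads, c * 4, dropout=0.0,
                                                 batch_first=True, norm_first=True)
        dec = lambda: nn.TransformerDecoderLayer(c, n_heads, c * 4, dropout=0.0,
                                                 batch_first=True, norm_first=True)
        self.card_enc = nn.ModuleList(enc() for _ in range(n_layers))
        self.action_enc = nn.ModuleList(enc() for _ in range(n_layers))
        self.h_action_enc = nn.ModuleList(enc() for _ in range(n_layers))
        self.global_mlp = ResMLP(c)
        self.a_card_mlp = ResMLP(c)
        self.action_card_dec = nn.ModuleList(dec() for _ in range(n_layers))
        self.action_hist_dec = nn.ModuleList(dec() for _ in range(n_layers))
        self.head = nn.Sequential(nn.Linear(c, c), nn.ReLU(), nn.Linear(c, 1))

    def forward(self, f_cards, f_global, f_h_actions, f_actions, a_card_index):
        for layer in self.card_enc:                 # f_cards -> n encoder layers
            f_cards = layer(f_cards)
        f_cards = f_cards + self.global_mlp(f_global).unsqueeze(1)  # f_cards = f_cards + f_global
        for layer in self.action_enc:               # f_actions -> n encoder layers
            f_actions = layer(f_actions)
        b = torch.arange(f_cards.shape[0], device=f_cards.device)[:, None]
        f_a_cards = self.a_card_mlp(f_cards[b, a_card_index])       # f_cards[id] -> ResMLP
        f_actions = f_a_cards + f_actions           # f_actions = f_a_cards + f_a_feats
        for layer in self.action_card_dec:          # actions attend to cards
            f_actions = layer(f_actions, f_cards)
        for layer in self.h_action_enc:             # f_h_actions -> n encoder layers
            f_h_actions = layer(f_h_actions)
        for layer in self.action_hist_dec:          # actions attend to history actions
            f_actions = layer(f_actions, f_h_actions)
        return self.head(f_actions)                 # (B, max_actions, 1)
```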
# Features
## Card (39)
- id: 2, uint16 -> 2 uint8, name+desc
- location: 1, int, 0: N/A, 1+: same as location2str (9)
- seq: 1, int, 0: N/A, 1+: seq in location
- owner: 1, int, 0: me, 1: oppo (2)
- position: 1, int, 0: N/A, same as position2str
- overlay: 1, int, 0: not, 1: xyz material
- attribute: 1, int, 0: N/A, same as attribute2str[2:]
- race: 1, int, 0: N/A, same as race2str
- level: 1, int, 0: N/A
- atk: 2, max 65535 to 2 bytes
- def: 2, max 65535 to 2 bytes
- type: 25, multi-hot, same as type2str
## Global
- lp: 2, max 65535 to 2 bytes
- oppo_lp: 2, max 65535 to 2 bytes
<!-- - turn: 8, int, trunc to 8 -->
- phase: 1, int, one-hot (10)
- is_first: 1, int, 0: False, 1: True
- is_my_turn: 1, int, 0: False, 1: True
- is_end: 1, int, 0: False, 1: True
## Legal Actions (max 8)
- spec index: 8, int, select target
- msg: 1, int (16)
- act: 1, int (11)
- N/A
- t: Set
- r: Reposition
- c: Special Summon
- s: Summon Face-up Attack
- m: Summon Face-down Defense
- a: Attack
- v: Activate
- v2: Activate the second effect
- v3: Activate the third effect
- v4: Activate the fourth effect
- yes/no: 1, int (3)
- N/A
- Yes
- No
- phase: 1, int (4)
- N/A
- Battle (b)
- Main Phase 2 (m)
- End Phase (e)
- cancel_finish: 1, int (3)
- N/A
- Cancel
- Finish
- position: 1, int, 0: N/A, same as position2str
- option: 1, int, 0: N/A
- place: 1, int (31), 0: N/A,
- 1-7: m
- 8-15: s
- 16-22: om
- 23-30: os
- attribute: 1, int, 0: N/A, same as attribute2id
## History Actions
- id: 2x4, uint16 -> 2 uint8, name+desc
- same as Legal Actions
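
The 2-byte fields above (id, atk, def, lp) pack a uint16 value into two uint8 slots, high byte first; the model later recombines them as `hi * 256 + lo`. A small illustrative numpy round trip (not the env's exact code):
```python
import numpy as np

def to_bytes2(x):
    # pack a uint16 value (e.g. atk, capped at 65535) into two uint8 features
    x = np.asarray(x, dtype=np.uint16)
    return np.stack([x >> 8, x & 0xFF], axis=-1).astype(np.uint8)

def from_bytes2(b):
    # recover the value on the model side: hi * 256 + lo
    b = b.astype(np.int64)
    return b[..., 0] * 256 + b[..., 1]

atk = np.array([0, 2500, 65535])
packed = to_bytes2(atk)              # [[0, 0], [9, 196], [255, 255]]
assert (from_bytes2(packed) == atk).all()
```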
# Deck
## Unsupported
- Many (Crossout Designator)
- Blackwing (add_counter)
- Magician (pendulum)
- Shaddoll (add_counter)
- Shiranui (Fairy Tail - Snow)
- Hero (random_selected)
# Message
## random_selected
Not supported
## add_counter
Not supported
## select_card
- `min` and `max` <= 5 are supported
- `min` > 5 throws an error
- `max` > 5 is truncated to 5
### Unsupported
- Fairy Tail - Snow (min=max=7)
- Pot of Prosperity (min=max=6)
## announce_card
Not supported:
- Alsei, the Sylvan High Protector
- Crossout Designator
## announce_attrib
Only 1 attribute is announced at a time.
Not supported:
- DNA Checkup
# Summon
## Tribute Summon
Through `select_tribute` (multi-select)
## Link Summon
Through `select_unselect_card` (select one card at a time)
## Synchro Summon
- `select_card` to choose the tuner (usually 1 card)
- `select_sum` to choose the non-tuner (one card at a time)
package("ygopro-core")
set_homepage("https://github.com/Fluorohydride/ygopro-core")
set_urls("https://github.com/Fluorohydride/ygopro-core.git")
add_deps("lua")
on_install("linux", function (package)
io.writefile("xmake.lua", [[
add_rules("mode.debug", "mode.release")
add_requires("lua")
target("ygopro-core")
set_kind("static")
add_files("*.cpp")
add_headerfiles("*.h")
add_packages("lua")
]])
local check_and_insert = function(file, line, insert)
local lines = table.to_array(io.lines(file))
if lines[line] ~= insert then
table.insert(lines, line, insert)
io.writefile(file, table.concat(lines, "\n"))
end
end
check_and_insert("field.h", 14, "#include <cstring>")
check_and_insert("interpreter.h", 11, "extern \"C\" {")
check_and_insert("interpreter.h", 15, "}")
local configs = {}
if package:config("shared") then
configs.kind = "shared"
end
import("package.tools.xmake").install(package)
os.cp("*.h", package:installdir("include", "ygopro-core"))
end)
package_end()
import os
import random
import time
from typing import Optional
from dataclasses import dataclass
import ygoenv
import numpy as np
import optree
import tyro
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, Elo
from ygoai.rl.agent import Agent
from ygoai.rl.buffer import DMCDictBuffer
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: str = "embeddings_en.npy"
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 8
"""the number of history actions to use"""
play_mode: str = "self"
"""the play mode, can be combination of 'self', 'bot', 'greedy', like 'self+bot'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
total_timesteps: int = 100000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 64
"""the number of parallel game environments"""
num_steps: int = 100
"""the number of steps per env per iteration"""
buffer_size: int = 200000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
minibatch_size: int = 256
"""the mini-batch size"""
eps: float = 0.05
"""the epsilon for exploration"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
log_p: float = 0.1
"""the probability of logging"""
save_freq: int = 100
"""the saving frequency (in terms of iterations)"""
compile: bool = True
"""if toggled, model will be compiled for better performance"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 32
"""the number of threads to use for envpool, defaults to `num_envs`"""
# to be filled in runtime
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
if __name__ == "__main__":
args = tyro.cli(Args)
args.batch_size = args.num_envs * args.num_steps
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.num_envs,
num_threads=args.env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
envs.num_envs = args.num_envs
obs_space = envs.observation_space
action_space = envs.action_space
print(f"obs_space={obs_space}, action_space={action_space}")
envs = RecordEpisodeStatistics(envs)
embeddings = np.load(args.embedding_file)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
agent.load_embeddings(embeddings)
if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
avg_win_rates = []
elo = Elo()
selfplay = "self" in args.play_mode
rb = DMCDictBuffer(
args.buffer_size,
obs_space,
action_space,
device=device,
n_envs=args.num_envs,
selfplay=selfplay,
)
gamma = np.float32(args.gamma)
global_step = 0
start_time = time.time()
warmup_steps = 0
obs, infos = envs.reset()
num_options = infos['num_options']
to_play = infos['to_play'] if selfplay else None
for iteration in range(1, args.num_iterations + 1):
agent.eval()
model_time = 0
env_time = 0
buffer_time = 0
collect_start = time.time()
for step in range(args.num_steps):
global_step += args.num_envs
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
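# Epsilon-greedy exploration: with probability eps pick a uniformly random
# option index per env (num_options holds each env's count of legal actions),
# otherwise act greedily w.r.t. the agent's predicted action values.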
if random.random() < args.eps:
actions_ = np.random.randint(num_options)
actions = torch.from_numpy(actions_).to(device)
else:
_start = time.time()
with torch.no_grad():
values = agent(obs)[0]
actions = torch.argmax(values, dim=1)
actions_ = actions.cpu().numpy()
model_time += time.time() - _start
_start = time.time()
next_obs, rewards, dones, infos = envs.step(actions_)
env_time += time.time() - _start
num_options = infos['num_options']
_start = time.time()
rb.add(obs, actions, rewards, to_play)
buffer_time += time.time() - _start
for idx, d in enumerate(dones):
if d:
_start = time.time()
rb.mark_episode(idx, gamma)
buffer_time += time.time() - _start
if random.random() < args.log_p:
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
writer.add_scalar("charts/episodic_return", episode_reward, global_step)
writer.add_scalar("charts/episodic_length", episode_length, global_step)
if selfplay:
if infos['is_selfplay'][idx]:
# win rate for the first player
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
avg_win_rates.append(1 - winner)
else:
# win rate of agent
winner = 0 if episode_reward == 1 else 1
elo.update(winner)
writer.add_scalar("charts/elo_rating", elo.r0, global_step)
else:
winner = 0 if episode_reward == 1 else 1
avg_win_rates.append(1 - winner)
elo.update(winner)
writer.add_scalar("charts/elo_rating", elo.r0, global_step)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}, elo={elo.r0}")
if len(avg_win_rates) > 100:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
avg_win_rates = []
to_play = infos['to_play'] if selfplay else None
obs = next_obs
collect_time = time.time() - collect_start
print(f"global_step={global_step}, collect_time={collect_time}, model_time={model_time}, env_time={env_time}, buffer_time={buffer_time}")
agent.train()
train_start = time.time()
model_time = 0
sample_time = 0
# ALGO LOGIC: training.
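# Deep Monte-Carlo style update: regress the Q-value of the taken action
# toward the discounted episode return stored when rb.mark_episode was called;
# invalid (padded) entries are replaced by their targets below so they add zero loss.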
_start = time.time()
b_inds = rb.get_data_indices()
if len(b_inds) < args.minibatch_size:
continue
np.random.shuffle(b_inds)
b_obs, b_actions, b_returns = rb._get_samples(b_inds)
sample_time += time.time() - _start
for start in range(0, len(b_inds), args.minibatch_size):
_start = time.time()
end = start + args.minibatch_size
mb_obs = {
k: v[start:end] for k, v in b_obs.items()
}
mb_actions = b_actions[start:end]
mb_returns = b_returns[start:end]
sample_time += time.time() - _start
_start = time.time()
outputs, valid = agent(mb_obs)
outputs = torch.gather(outputs, 1, mb_actions).squeeze(1)
outputs = torch.where(valid, outputs, mb_returns)
loss = F.mse_loss(mb_returns, outputs)
loss = loss * (args.minibatch_size / valid.float().sum())
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
model_time += time.time() - _start
train_time = time.time() - train_start
print(f"global_step={global_step}, train_time={train_time}, model_time={model_time}, sample_time={sample_time}")
writer.add_scalar("losses/value_loss", loss.item(), global_step)
writer.add_scalar("losses/q_values", outputs.mean().item(), global_step)
if not rb.full or iteration % 10 == 0:
torch.cuda.empty_cache()
if iteration == 10:
warmup_steps = global_step
start_time = time.time()
if iteration > 10:
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.save_freq == 0:
save_path = f"checkpoints/agent.pt"
print(f"Saving model to {save_path}")
torch.save(agent.state_dict(), save_path)
envs.close()
writer.close()
import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
import ygoenv
import numpy as np
import optree
import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import Agent
@dataclass
class Args:
seed: int = 1
"""the random seed"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 8
"""the number of history actions to use"""
player: int = -1
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False
"""whether to play the game"""
selfplay: bool = False
"""whether to use selfplay"""
num_episodes: int = 1024
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
bot_type: Literal["random", "greedy"] = "greedy"
"""the type of bot to use"""
strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used"""
agent: bool = False
"""whether to use the agent"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: str = "checkpoints/agent.pt"
"""the checkpoint to load"""
embedding_file: str = "embeddings_en.npy"
"""the embedding file for card embeddings"""
compile: bool = False
"""if toggled, the model will be compiled"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
if __name__ == "__main__":
args = tyro.cli(Args)
if args.play:
args.num_envs = 1
args.verbose = True
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.lang, args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
seed = args.seed
random.seed(seed)
np.random.seed(seed)
if args.agent:
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
num_envs = args.num_envs
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=args.env_threads,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
player=args.player,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
verbose=args.verbose,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
if args.agent:
embeddings = np.load(args.embedding_file)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
agent = agent.eval()
state_dict = torch.load(args.checkpoint, map_location=device)
if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
agent.load_state_dict(state_dict)
else:
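# Checkpoints saved from a torch.compile'd model prefix their keys with
# "_orig_mod."; strip the prefix when loading into an uncompiled model.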
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
agent.load_state_dict(state_dict)
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), envs.reset()[0])
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
agent = torch.jit.optimize_for_inference(traced_model)
obs, infos = envs.reset()
episode_rewards = []
episode_lengths = []
win_rates = []
win_reasons = []
step = 0
start = time.time()
start_step = step
model_time = env_time = 0
while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
start = time.time()
start_step = step
model_time = env_time = 0
if args.agent:
_start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
values = agent(obs)[0]
actions = torch.argmax(values, dim=1).cpu().numpy()
model_time += time.time() - _start
else:
if args.strategy == "random":
actions = np.random.randint(infos['num_options'])
else:
actions = np.zeros(num_envs, dtype=np.int32)
# for k, v in obs.items():
# v = v[0]
# if k == 'cards_':
# v = np.concatenate([np.arange(v.shape[0])[:, None], v], axis=1)
# print(k, v.tolist())
# print(infos)
# print(actions[0])
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
env_time += time.time() - _start
step += 1
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
if args.selfplay:
pl = 1 if infos['to_play'][idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
if episode_reward == -1:
win = 0
else:
win = 1
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
# print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}")
if len(episode_lengths) >= args.num_episodes:
break
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
if not args.play:
total_time = time.time() - start
total_steps = (step - start_step) * num_envs
print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
import os
import time
from dataclasses import dataclass
import numpy as np
import voyageai
import tyro
from ygoai.embed import read_cards
from ygoai.utils import load_deck
@dataclass
class Args:
deck_dir: str = "../assets/deck"
"""the directory of ydk files"""
code_list_file: str = "code_list.txt"
"""the file containing the list of card codes"""
embeddings_file: str = "embeddings.npy"
"""the npz file containing the embeddings of the cards"""
cards_db: str = "../assets/locale/en/cards.cdb"
"""the cards database file"""
batch_size: int = 64
"""the batch size for embedding generation"""
wait_time: float = 0.1
"""the time to wait between each batch"""
def get_embeddings(texts, batch_size=64, wait_time=0.1, verbose=False):
vo = voyageai.Client()
embeddings = []
for i in range(0, len(texts), batch_size):
if verbose:
print(f"Embedding {i} / {len(texts)}")
embeddings += vo.embed(
texts[i : i + batch_size], model="voyage-2", truncation=False).embeddings
time.sleep(wait_time)
embeddings = np.array(embeddings, dtype=np.float32)
return embeddings
def read_decks(d):
# iterate over ydk files
codes = []
for file in os.listdir(d):
if file.endswith(".ydk"):
file = os.path.join(d, file)
codes += load_deck(file)
return set(codes)
def read_texts(cards_db, codes):
df, cards = read_cards(cards_db)
code2card = {c.code: c for c in cards}
texts = []
for code in codes:
texts.append(code2card[code].format())
return texts
if __name__ == "__main__":
args = tyro.cli(Args)
deck_dir = args.deck_dir
code_list_file = args.code_list_file
embeddings_file = args.embeddings_file
cards_db = args.cards_db
# read code_list file
if not os.path.exists(code_list_file):
with open(code_list_file, "w") as f:
f.write("")
with open(code_list_file, "r") as f:
code_list = f.readlines()
code_list = [int(code.strip()) for code in code_list]
print(f"The database contains {len(code_list)} cards.")
# read embeddings
if not os.path.exists(embeddings_file):
sample_embedding = get_embeddings(["test"])[0]
all_embeddings = np.zeros((0, len(sample_embedding)), dtype=np.float32)
else:
all_embeddings = np.load(embeddings_file)
print("Embedding dim:", all_embeddings.shape[1])
assert len(all_embeddings) == len(code_list), f"The number of embeddings({len(all_embeddings)}) does not match the number of cards."
all_codes = set(code_list)
new_codes = []
for code in read_decks(deck_dir):
if code not in all_codes:
new_codes.append(code)
if new_codes == []:
print("No new cards have been added to the database.")
exit()
new_texts = read_texts(cards_db, new_codes)
embeddings = get_embeddings(new_texts, args.batch_size, args.wait_time, verbose=True)
# add new embeddings
all_embeddings = np.concatenate([all_embeddings, np.array(embeddings)], axis=0)
# update code_list
code_list += new_codes
# save embeddings and code_list
np.save(embeddings_file, all_embeddings)
with open(code_list_file, "w") as f:
f.write("\n".join(map(str, code_list)) + "\n")
print(f"{len(new_codes)} new cards have been added to the database.")
import io
import os
from setuptools import find_packages, setup
NAME = 'ygoai'
IMPORT_NAME = 'ygoai'
DESCRIPTION = "A Yu-Gi-Oh! AI."
URL = 'https://github.com/sbl1996/ygo-agent'
EMAIL = 'sbl1996@gmail.com'
AUTHOR = 'Hastur'
REQUIRES_PYTHON = '>=3.8.0'
VERSION = None
REQUIRED = []
here = os.path.dirname(os.path.abspath(__file__))
try:
with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
long_description = '\n' + f.read()
except FileNotFoundError:
long_description = DESCRIPTION
about = {}
if not VERSION:
with open(os.path.join(here, IMPORT_NAME, '_version.py')) as f:
exec(f.read(), about)
else:
about['__version__'] = VERSION
setup(
name=NAME,
version=about['__version__'],
description=DESCRIPTION,
long_description=long_description,
long_description_content_type='text/markdown',
author=AUTHOR,
author_email=EMAIL,
python_requires=REQUIRES_PYTHON,
url=URL,
packages=find_packages(include=['ygoai*']),
install_requires=REQUIRED,
dependency_links=[],
license='MIT',
)
add_rules("mode.debug", "mode.release")
add_repositories("my-repo repo")
add_requires(
"ygopro-core", "pybind11 2.10.*", "fmt 10.2.*", "glog 0.6.0",
"concurrentqueue 1.0.4", "sqlitecpp 3.2.1", "unordered_dense 4.4.*")
target("dummy_ygopro")
add_rules("python.library")
add_files("ygoenv/ygoenv/dummy/*.cpp")
add_packages("pybind11", "fmt", "glog", "concurrentqueue")
set_languages("c++17")
set_policy("build.optimization.lto", true)
add_includedirs("ygoenv")
after_build(function (target)
local install_target = "$(projectdir)/ygoenv/ygoenv/dummy"
os.mv(target:targetfile(), install_target)
print("move target to " .. install_target)
end)
target("ygopro_ygoenv")
add_rules("python.library")
add_files("ygoenv/ygoenv/ygopro/*.cpp")
add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "ygopro-core")
set_languages("c++17")
add_cxxflags("-flto=auto -fno-fat-lto-objects -fvisibility=hidden -march=native")
add_includedirs("ygoenv")
-- for _, header in ipairs(os.files("ygoenv/ygoenv/core/*.h")) do
-- set_pcxxheader(header)
-- end
after_build(function (target)
local install_target = "$(projectdir)/ygoenv/ygoenv/ygopro"
os.mv(target:targetfile(), install_target)
print("Move target to " .. install_target)
end)
__version__ = "0.0.1"
from enum import Flag, auto, unique, IntFlag
__ = lambda s: s
AMOUNT_ATTRIBUTES = 7
AMOUNT_RACES = 25
ATTRIBUTES_OFFSET = 1010
LINK_MARKERS = {
0x001: __("bottom left"),
0x002: __("bottom"),
0x004: __("bottom right"),
0x008: __("left"),
0x020: __("right"),
0x040: __("top left"),
0x080: __("top"),
0x100: __("top right")
}
@unique
class LOCATION(IntFlag):
DECK = 0x1
HAND = 0x2
MZONE = 0x4
SZONE = 0x8
GRAVE = 0x10
REMOVED = 0x20
EXTRA = 0x40
OVERLAY = 0x80
ONFIELD = MZONE | SZONE
FZONE = 0x100
PZONE = 0x200
DECKBOT = 0x10001 # Return to deck bottom
DECKSHF = 0x20001 # Return to deck and shuffle
location2str = {
LOCATION.DECK: 'Deck',
LOCATION.HAND: 'Hand',
LOCATION.MZONE: 'Main Monster Zone',
LOCATION.SZONE: 'Spell & Trap Zone',
LOCATION.GRAVE: 'Graveyard',
LOCATION.REMOVED: 'Banished',
LOCATION.EXTRA: 'Extra Deck',
LOCATION.FZONE: 'Field Zone',
}
all_locations = list(location2str.keys())
PHASES = {
0x01: __('draw phase'),
0x02: __('standby phase'),
0x04: __('main1 phase'),
0x08: __('battle start phase'),
0x10: __('battle step phase'),
0x20: __('damage phase'),
0x40: __('damage calculation phase'),
0x80: __('battle phase'),
0x100: __('main2 phase'),
0x200: __('end phase'),
}
@unique
class POSITION(IntFlag):
FACEUP_ATTACK = 0x1
FACEDOWN_ATTACK = 0x2
FACEUP_DEFENSE = 0x4
FACEUP = FACEUP_ATTACK | FACEUP_DEFENSE
FACEDOWN_DEFENSE = 0x8
FACEDOWN = FACEDOWN_ATTACK | FACEDOWN_DEFENSE
ATTACK = FACEUP_ATTACK | FACEDOWN_ATTACK
DEFENSE = FACEUP_DEFENSE | FACEDOWN_DEFENSE
position2str = {
POSITION.FACEUP_ATTACK: "Face-up Attack",
POSITION.FACEDOWN_ATTACK: "Face-down Attack",
POSITION.FACEUP_DEFENSE: "Face-up Defense",
POSITION.FACEUP: "Face-up",
POSITION.FACEDOWN_DEFENSE: "Face-down Defense",
POSITION.FACEDOWN: "Face-down",
POSITION.ATTACK: "Attack",
POSITION.DEFENSE: "Defense",
}
all_positions = list(position2str.keys())
RACES_OFFSET = 1020
@unique
class QUERY(IntFlag):
CODE = 0x1
POSITION = 0x2
ALIAS = 0x4
TYPE = 0x8
LEVEL = 0x10
RANK = 0x20
ATTRIBUTE = 0x40
RACE = 0x80
ATTACK = 0x100
DEFENSE = 0x200
BASE_ATTACK = 0x400
BASE_DEFENSE = 0x800
REASON = 0x1000
REASON_CARD = 0x2000
EQUIP_CARD = 0x4000
TARGET_CARD = 0x8000
OVERLAY_CARD = 0x10000
COUNTERS = 0x20000
OWNER = 0x40000
STATUS = 0x80000
LSCALE = 0x200000
RSCALE = 0x400000
LINK = 0x800000
@unique
class TYPE(IntFlag):
MONSTER = 0x1
SPELL = 0x2
TRAP = 0x4
NORMAL = 0x10
EFFECT = 0x20
FUSION = 0x40
RITUAL = 0x80
TRAPMONSTER = 0x100
SPIRIT = 0x200
UNION = 0x400
DUAL = 0x800
TUNER = 0x1000
SYNCHRO = 0x2000
TOKEN = 0x4000
QUICKPLAY = 0x10000
CONTINUOUS = 0x20000
EQUIP = 0x40000
FIELD = 0x80000
COUNTER = 0x100000
FLIP = 0x200000
TOON = 0x400000
XYZ = 0x800000
PENDULUM = 0x1000000
SPSUMMON = 0x2000000
LINK = 0x4000000
# for this mud only
EXTRA = XYZ | SYNCHRO | FUSION | LINK
type2str = {
TYPE.MONSTER: "Monster",
TYPE.SPELL: "Spell",
TYPE.TRAP: "Trap",
TYPE.NORMAL: "Normal",
TYPE.EFFECT: "Effect",
TYPE.FUSION: "Fusion",
TYPE.RITUAL: "Ritual",
TYPE.TRAPMONSTER: "Trap Monster",
TYPE.SPIRIT: "Spirit",
TYPE.UNION: "Union",
TYPE.DUAL: "Dual",
TYPE.TUNER: "Tuner",
TYPE.SYNCHRO: "Synchro",
TYPE.TOKEN: "Token",
TYPE.QUICKPLAY: "Quick-play",
TYPE.CONTINUOUS: "Continuous",
TYPE.EQUIP: "Equip",
TYPE.FIELD: "Field",
TYPE.COUNTER: "Counter",
TYPE.FLIP: "Flip",
TYPE.TOON: "Toon",
TYPE.XYZ: "XYZ",
TYPE.PENDULUM: "Pendulum",
TYPE.SPSUMMON: "Special",
TYPE.LINK: "Link"
}
all_types = list(type2str.keys())
@unique
class ATTRIBUTE(IntFlag):
ALL = 0x7f
NONE = 0x0 # Token
EARTH = 0x01
WATER = 0x02
FIRE = 0x04
WIND = 0x08
LIGHT = 0x10
DARK = 0x20
DEVINE = 0x40
attribute2str = {
ATTRIBUTE.ALL: 'All',
ATTRIBUTE.NONE: 'None',
ATTRIBUTE.EARTH: 'Earth',
ATTRIBUTE.WATER: 'Water',
ATTRIBUTE.FIRE: 'Fire',
ATTRIBUTE.WIND: 'Wind',
ATTRIBUTE.LIGHT: 'Light',
ATTRIBUTE.DARK: 'Dark',
ATTRIBUTE.DEVINE: 'Divine'
}
all_attributes = list(attribute2str.keys())
@unique
class RACE(IntFlag):
ALL = 0x3ffffff
NONE = 0x0 # Token
WARRIOR = 0x1
SPELLCASTER = 0x2
FAIRY = 0x4
FIEND = 0x8
ZOMBIE = 0x10
MACHINE = 0x20
AQUA = 0x40
PYRO = 0x80
ROCK = 0x100
WINDBEAST = 0x200
PLANT = 0x400
INSECT = 0x800
THUNDER = 0x1000
DRAGON = 0x2000
BEAST = 0x4000
BEASTWARRIOR = 0x8000
DINOSAUR = 0x10000
FISH = 0x20000
SEASERPENT = 0x40000
REPTILE = 0x80000
PSYCHO = 0x100000
DEVINE = 0x200000
CREATORGOD = 0x400000
WYRM = 0x800000
CYBERSE = 0x1000000
ILLUSION = 0x2000000
race2str = {
RACE.NONE: "None",
RACE.WARRIOR: 'Warrior',
RACE.SPELLCASTER: 'Spellcaster',
RACE.FAIRY: 'Fairy',
RACE.FIEND: 'Fiend',
RACE.ZOMBIE: 'Zombie',
RACE.MACHINE: 'Machine',
RACE.AQUA: 'Aqua',
RACE.PYRO: 'Pyro',
RACE.ROCK: 'Rock',
RACE.WINDBEAST: 'Windbeast',
RACE.PLANT: 'Plant',
RACE.INSECT: 'Insect',
RACE.THUNDER: 'Thunder',
RACE.DRAGON: 'Dragon',
RACE.BEAST: 'Beast',
RACE.BEASTWARRIOR: 'Beast Warrior',
RACE.DINOSAUR: 'Dinosaur',
RACE.FISH: 'Fish',
RACE.SEASERPENT: 'Sea Serpent',
RACE.REPTILE: 'Reptile',
RACE.PSYCHO: 'Psycho',
RACE.DEVINE: 'Divine',
RACE.CREATORGOD: 'Creator God',
RACE.WYRM: 'Wyrm',
RACE.CYBERSE: 'Cyberse',
RACE.ILLUSION: 'Illusion'
}
all_races = list(race2str.keys())
@unique
class REASON(IntFlag):
DESTROY = 0x1
RELEASE = 0x2
TEMPORARY = 0x4
MATERIAL = 0x8
SUMMON = 0x10
BATTLE = 0x20
EFFECT = 0x40
COST = 0x80
ADJUST = 0x100
LOST_TARGET = 0x200
RULE = 0x400
SPSUMMON = 0x800
DISSUMMON = 0x1000
FLIP = 0x2000
DISCARD = 0x4000
RDAMAGE = 0x8000
RRECOVER = 0x10000
RETURN = 0x20000
FUSION = 0x40000
SYNCHRO = 0x80000
RITUAL = 0x100000
XYZ = 0x200000
REPLACE = 0x1000000
DRAW = 0x2000000
REDIRECT = 0x4000000
REVEAL = 0x8000000
LINK = 0x10000000
LOST_OVERLAY = 0x20000000
MAINTENANCE = 0x40000000
ACTION = 0x80000000
PROCEDURE = SYNCHRO | XYZ | LINK
@unique
class OPCODE(IntFlag):
ADD = 0x40000000
SUB = 0x40000001
MUL = 0x40000002
DIV = 0x40000003
AND = 0x40000004
OR = 0x40000005
NEG = 0x40000006
NOT = 0x40000007
ISCODE = 0x40000100
ISSETCARD = 0x40000101
ISTYPE = 0x40000102
ISRACE = 0x40000103
ISATTRIBUTE = 0x40000104
@unique
class INFORM(Flag):
PLAYER = auto()
OPPONENT = auto()
ALL = PLAYER | OPPONENT
@unique
class DECK(Flag):
OWNED = auto()
OTHER = auto()
PUBLIC = auto()
ALL = OWNED | OTHER # should only be used for admins
VISIBLE = OWNED | PUBLIC # default scope for players
from typing import List, Union
from dataclasses import dataclass
import sqlite3
import pandas as pd
from ygoai.constants import TYPE, type2str, attribute2str, race2str
def parse_types(value):
types = []
all_types = list(type2str.keys())
for t in all_types:
if value & t:
types.append(type2str[t])
return types
def parse_attribute(value):
attribute = attribute2str.get(value, None)
assert attribute, "Invalid attribute, value: " + str(value)
return attribute
def parse_race(value):
race = race2str.get(value, None)
assert race, "Invalid race, value: " + str(value)
return race
@dataclass
class Card:
code: int
name: str
desc: str
types: List[str]
def format(self):
return format_card(self)
@dataclass
class MonsterCard(Card):
atk: int
def_: int
level: int
race: str
attribute: str
@dataclass
class SpellCard(Card):
pass
@dataclass
class TrapCard(Card):
pass
def format_monster_card(card: MonsterCard):
name = card.name
typ = "/".join(card.types)
attribute = card.attribute
race = card.race
level = str(card.level)
atk = str(card.atk)
if atk == '-2':
atk = '?'
def_ = str(card.def_)
if def_ == '-2':
def_ = '?'
if typ == 'Monster/Normal':
desc = "-"
else:
desc = card.desc
columns = [name, typ, attribute, race, level, atk, def_, desc]
return " | ".join(columns)
def format_spell_trap_card(card: Union[SpellCard, TrapCard]):
name = card.name
typ = "/".join(card.types)
desc = card.desc
columns = [name, typ, desc]
return " | ".join(columns)
def format_card(card: Card):
if isinstance(card, MonsterCard):
return format_monster_card(card)
elif isinstance(card, (SpellCard, TrapCard)):
return format_spell_trap_card(card)
else:
raise ValueError("Invalid card type: " + str(card))
## For analyzing cards.db
def parse_monster_card(data) -> MonsterCard:
code = int(data['id'])
name = data['name']
desc = data['desc']
types = parse_types(int(data['type']))
atk = int(data['atk'])
def_ = int(data['def'])
level = int(data['level'])
if level >= 16:
# pendulum monster
level = level % 16
race = parse_race(int(data['race']))
attribute = parse_attribute(int(data['attribute']))
return MonsterCard(code, name, desc, types, atk, def_, level, race, attribute)
def parse_spell_card(data) -> SpellCard:
code = int(data['id'])
name = data['name']
desc = data['desc']
types = parse_types(int(data['type']))
return SpellCard(code, name, desc, types)
def parse_trap_card(data) -> TrapCard:
code = int(data['id'])
name = data['name']
desc = data['desc']
types = parse_types(int(data['type']))
return TrapCard(code, name, desc, types)
def parse_card(data) -> Card:
type_ = data['type']
if type_ & TYPE.MONSTER:
return parse_monster_card(data)
elif type_ & TYPE.SPELL:
return parse_spell_card(data)
elif type_ & TYPE.TRAP:
return parse_trap_card(data)
else:
raise ValueError("Invalid card type: " + str(type_))
def read_cards(cards_path):
conn = sqlite3.connect(cards_path)
cursor = conn.cursor()
cursor.execute("SELECT * FROM datas")
datas_rows = cursor.fetchall()
datas_columns = [description[0] for description in cursor.description]
datas_df = pd.DataFrame(datas_rows, columns=datas_columns)
cursor.execute("SELECT * FROM texts")
texts_rows = cursor.fetchall()
texts_columns = [description[0] for description in cursor.description]
texts_df = pd.DataFrame(texts_rows, columns=texts_columns)
cursor.close()
conn.close()
texts_df = texts_df.loc[:, ['id', 'name', 'desc']]
merged_df = pd.merge(texts_df, datas_df, on='id')
cards_data = merged_df.to_dict('records')
cards = [parse_card(data) for data in cards_data]
return merged_df, cards
import torch
import torch.nn as nn
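# Numeric inputs (atk/def/lp) arrive as two uint8 bytes. bytes_to_bin recombines
# them into the raw value and soft-encodes it over n_bins thermometer-style bins
# clamped to [0, 1]; make_bin_params spaces the first sig_bins bins up to 8000
# and the remainder up to x_max, giving finer resolution for common values.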
def bytes_to_bin(x, points, intervals):
x = x[..., 0] * 256 + x[..., 1]
x = x.unsqueeze(-1)
return torch.clamp((x - points + intervals) / intervals, 0, 1)
def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = torch.linspace(0, x_max1, sig_bins + 1, dtype=torch.float32)[1:]
points2 = torch.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=torch.float32)[1:]
points = torch.cat([points1, points2], dim=0)
intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0)
return points, intervals
class Agent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Agent, self).__init__()
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
self.loc_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.seq_embed = nn.Embedding(41, c)
self.seq_norm = nn.LayerNorm(c, elementwise_affine=affine)
linear = lambda in_features, out_features: nn.Linear(in_features, out_features, bias=bias)
c_num = c // 8
n_bins = 32
self.num_fc = nn.Sequential(
linear(n_bins, c_num),
nn.ReLU(),
)
bin_points, bin_intervals = make_bin_params(n_bins=n_bins)
self.bin_points = nn.Parameter(bin_points, requires_grad=False)
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
if embedding_shape is None:
n_embed, embed_dim = 150, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
self.id_embed = nn.Embedding(n_embed, embed_dim)
self.id_fc_emb = linear(1024, c // 4)
self.id_norm = nn.LayerNorm(c // 4, elementwise_affine=False)
self.owner_embed = nn.Embedding(2, c // 16 * 2)
self.position_embed = nn.Embedding(9, c // 16 * 2)
self.overley_embed = nn.Embedding(2, c // 16)
self.attribute_embed = nn.Embedding(8, c // 16)
self.race_embed = nn.Embedding(27, c // 16)
self.level_embed = nn.Embedding(14, c // 16)
self.type_fc_emb = linear(25, c // 16 * 2)
self.atk_fc_emb = linear(c_num, c // 16)
self.def_fc_emb = linear(c_num, c // 16)
self.feat_norm = nn.LayerNorm(c // 4 * 3, elementwise_affine=affine)
self.na_card_embed = nn.Parameter(torch.randn(1, c) * 0.02, requires_grad=True)
num_heads = max(2, c // 128)
self.card_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_card_layers)
])
self.card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.lp_fc_emb = linear(c_num, c // 4)
self.oppo_lp_fc_emb = linear(c_num, c // 4)
self.phase_embed = nn.Embedding(10, c // 4)
self.if_first_embed = nn.Embedding(2, c // 8)
self.is_my_turn_embed = nn.Embedding(2, c // 8)
self.global_norm_pre = nn.LayerNorm(c, elementwise_affine=affine)
self.global_net = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.global_norm = nn.LayerNorm(c, elementwise_affine=False)
divisor = 8
self.a_msg_embed = nn.Embedding(30, c // divisor)
self.a_act_embed = nn.Embedding(11, c // divisor)
self.a_yesno_embed = nn.Embedding(3, c // divisor)
self.a_phase_embed = nn.Embedding(4, c // divisor)
self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
self.a_position_embed = nn.Embedding(5, c // divisor)
self.a_option_embed = nn.Embedding(4, c // divisor)
self.a_place_embed = nn.Embedding(31, c // divisor // 2)
self.a_attrib_embed = nn.Embedding(31, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.a_card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.a_card_proj = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.h_id_fc_emb = linear(1024, c)
self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
num_heads = max(2, c // 128)
self.action_card_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
self.value_head = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
self.init_embeddings()
def init_embeddings(self, scale=0.0001):
for n, m in self.named_modules():
if isinstance(m, nn.Embedding):
nn.init.uniform_(m.weight, -scale, scale)
elif n in ["atk_fc_emb", "def_fc_emb"]:
nn.init.uniform_(m.weight, -scale * 10, scale * 10)
elif n in ["lp_fc_emb", "oppo_lp_fc_emb"]:
nn.init.uniform_(m.weight, -scale, scale)
elif "fc_emb" in n:
nn.init.uniform_(m.weight, -scale, scale)
def load_embeddings(self, embeddings, freeze=True):
weight = self.id_embed.weight
embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
unknown_embed = embeddings.mean(dim=0, keepdim=True)
embeddings = torch.cat([unknown_embed, embeddings], dim=0)
weight.data.copy_(embeddings)
if freeze:
weight.requires_grad = False
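# Hedged usage sketch (file name and array shape are illustrative, not part of this repo):
# load_embeddings expects an (n_embed, embed_dim) float array of pre-trained card
# embeddings in the card-code order used elsewhere; the row for the "unknown" card
# (index 0) is derived here as the mean of all rows.
#   embeddings = np.load("card_embeddings.npy")   # hypothetical file, assumes numpy is imported
#   model.load_embeddings(embeddings, freeze=True)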
def num_transform(self, x):
return self.num_fc(bytes_to_bin(x, self.bin_points, self.bin_intervals))
def encode_action_(self, x):
x_a_msg = self.a_msg_embed(x[:, :, 0])
x_a_act = self.a_act_embed(x[:, :, 1])
x_a_yesno = self.a_yesno_embed(x[:, :, 2])
x_a_phase = self.a_phase_embed(x[:, :, 3])
x_a_cancel = self.a_cancel_finish_embed(x[:, :, 4])
x_a_position = self.a_position_embed(x[:, :, 5])
x_a_option = self.a_option_embed(x[:, :, 6])
x_a_place = self.a_place_embed(x[:, :, 7])
x_a_attrib = self.a_attrib_embed(x[:, :, 8])
return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_place, x_a_attrib
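# get_action_card_: each action row stores up to max_multi_select card references,
# every reference packed as two bytes (hi * 256 + lo) indexing into f_cards, where
# index 0 points at the prepended "not applicable" card embedding. The referenced
# card features are gathered and averaged per action with a mask; the first slot is
# always kept so the divisor never becomes zero.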
def get_action_card_(self, x, f_cards):
b, n, c = x.shape
m = c // 2
spec_index = x.view(b, n, m, 2)
spec_index = spec_index[..., 0] * 256 + spec_index[..., 1]
mask = spec_index != 0
mask[:, :, 0] = True
spec_index = spec_index.view(b, -1)
B = torch.arange(b, device=spec_index.device)
f_a_actions = f_cards[B[:, None], spec_index]
f_a_actions = f_a_actions.view(b, n, m, -1)
f_a_actions = (f_a_actions * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return f_a_actions
def get_h_action_card_(self, x):
b, n, _ = x.shape
x_ids = x.view(b, n, -1, 2)
x_ids = x_ids[..., 0] * 256 + x_ids[..., 1]
mask = x_ids != 0
mask[:, :, 0] = True
x_ids = self.id_embed(x_ids)
x_ids = self.h_id_fc_emb(x_ids)
x_ids = (x_ids * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return x_ids
def encode_card_id(self, x):
x_id = self.id_embed(x)
x_id = self.id_fc_emb(x_id)
x_id = self.id_norm(x_id)
return x_id
def encode_card_feat1(self, x1):
x_owner = self.owner_embed(x1[:, :, 2])
x_position = self.position_embed(x1[:, :, 3])
x_overley = self.overley_embed(x1[:, :, 4])
x_attribute = self.attribute_embed(x1[:, :, 5])
x_race = self.race_embed(x1[:, :, 6])
x_level = self.level_embed(x1[:, :, 7])
return x_owner, x_position, x_overley, x_attribute, x_race, x_level
def encode_card_feat2(self, x2):
x_atk = self.num_transform(x2[:, :, 0:2])
x_atk = self.atk_fc_emb(x_atk)
x_def = self.num_transform(x2[:, :, 2:4])
x_def = self.def_fc_emb(x_def)
x_type = self.type_fc_emb(x2[:, :, 4:])
return x_atk, x_def, x_type
def encode_global(self, x):
x_global_1 = x[:, :4].float()
x_g_lp = self.lp_fc_emb(self.num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))
x_global_2 = x[:, 4:-1].long()
x_g_phase = self.phase_embed(x_global_2[:, 0])
x_g_if_first = self.if_first_embed(x_global_2[:, 1])
x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 2])
x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
return x_global
def forward(self, x):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_card_ids = x_cards[:, :, :2].long()
x_card_ids = x_card_ids[..., 0] * 256 + x_card_ids[..., 1]
x_cards_1 = x_cards[:, :, 2:10].long()
x_cards_2 = x_cards[:, :, 10:].to(torch.float32)
x_id = self.encode_card_id(x_card_ids)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 0]))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 1]))
x_feat1 = self.encode_card_feat1(x_cards_1)
x_feat2 = self.encode_card_feat2(x_cards_2)
x_feat = torch.cat([*x_feat1, *x_feat2], dim=-1)
x_feat = self.feat_norm(x_feat)
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
x_global = self.global_norm_pre(x_global)
f_global = x_global + self.global_net(x_global)
f_global = self.global_norm(f_global)
f_cards = f_cards + f_global.unsqueeze(1)
x_actions = x_actions.long()
max_multi_select = (x_actions.shape[-1] - 9) // 2
mo = max_multi_select * 2
f_a_cards = self.get_action_card_(x_actions[..., :mo], f_cards)
f_a_cards = f_a_cards + self.a_card_proj(self.a_card_norm(f_a_cards))
x_a_feats = self.encode_action_(x_actions[..., mo:])
x_a_feats = torch.cat(x_a_feats, dim=-1)
f_actions = f_a_cards + self.a_feat_norm(x_a_feats)
mask = x_actions[:, :, mo] == 0 # msg == 0
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :mo])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = self.action_norm(f_actions)
values = self.value_head(f_actions)[..., 0]
values = torch.tanh(values)
values = torch.where(mask, torch.full_like(values, -1.01), values)
return values, valid
from abc import ABC, abstractmethod
import warnings
from typing import Dict, Tuple, Union, NamedTuple, List, Any, Optional
import numpy as np
import numba
import torch as th
from gymnasium import spaces
import psutil
def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only CPU and CUDA.
By default, it tries to use the GPU.
:param device: One of 'auto', 'cuda', 'cpu'
:return: Supported PyTorch device
"""
# Cuda by default
if device == "auto":
device = "cuda"
# Force conversion to th.device
device = th.device(device)
# Cuda not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")
return device
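# Minimal usage sketch: "auto" prefers CUDA and silently falls back to CPU when
# torch.cuda.is_available() is False, e.g.
#   device = get_device("auto")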
def get_action_dim(action_space: spaces.Space) -> int:
"""
Get the dimension of the action space.
:param action_space:
:return:
"""
if isinstance(action_space, spaces.Box):
return int(np.prod(action_space.shape))
elif isinstance(action_space, spaces.Discrete):
# Action is an int
return 1
elif isinstance(action_space, spaces.MultiDiscrete):
# Number of discrete actions
return int(len(action_space.nvec))
elif isinstance(action_space, spaces.MultiBinary):
# Number of binary actions
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return int(action_space.n)
else:
raise NotImplementedError(f"{action_space} action space is not supported")
def get_obs_shape(
observation_space: spaces.Space,
) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
"""
Get the shape of the observation (useful for the buffers).
:param observation_space:
:return:
"""
if isinstance(observation_space, spaces.Box):
return observation_space.shape
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
return (1,)
elif isinstance(observation_space, spaces.MultiDiscrete):
# Number of discrete features
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return observation_space.shape
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]
else:
raise NotImplementedError(f"{observation_space} observation space is not supported")
TensorDict = Dict[str, th.Tensor]
class DictReplayBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
dones: th.Tensor
rewards: th.Tensor
class ReplayBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
next_observations: th.Tensor
dones: th.Tensor
rewards: th.Tensor
class BaseBuffer(ABC):
"""
Base class that represents a buffer (rollout or replay)
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
to which the values will be converted
:param n_envs: Number of parallel environments
"""
observation_space: spaces.Space
obs_shape: Tuple[int, ...]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
):
super().__init__()
self.buffer_size = buffer_size
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment]
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
self.device = get_device(device)
self.n_envs = n_envs
@staticmethod
def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
"""
Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
to [n_steps * n_envs, ...] (which maintains the order)
:param arr:
:return:
"""
shape = arr.shape
if len(shape) < 3:
shape = (*shape, 1)
return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
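# Worked example (illustrative): an array of shape (n_steps=5, n_envs=2, 3) is swapped
# to (2, 5, 3) and then reshaped to (10, 3), so each environment's steps stay contiguous.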
def size(self) -> int:
"""
:return: The current size of the buffer
"""
if self.full:
return self.buffer_size
return self.pos
def add(self, *args, **kwargs) -> None:
"""
Add elements to the buffer.
"""
raise NotImplementedError()
def extend(self, *args, **kwargs) -> None:
"""
Add a new batch of transitions to the buffer
"""
# Do a for loop along the batch axis
for data in zip(*args):
self.add(*data)
def reset(self) -> None:
"""
Reset the buffer.
"""
self.pos = 0
self.full = False
def sample(self, batch_size: int):
"""
:param batch_size: Number of elements to sample
:return:
"""
upper_bound = self.buffer_size if self.full else self.pos
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
return self._get_samples(batch_inds)
@abstractmethod
def _get_samples(self, batch_inds: np.ndarray):
"""
:param batch_inds:
:param env:
:return:
"""
raise NotImplementedError()
def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
"""
Convert a numpy array to a PyTorch tensor.
Note: it copies the data by default
:param array:
:param copy: Whether or not to copy the data (may be useful to avoid changing things
by reference). This argument is inoperative if the device is not the CPU.
:return:
"""
if copy:
return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device)
class ReplayBuffer(BaseBuffer):
"""
Replay buffer used in off-policy algorithms like SAC/TD3.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
of the replay buffer which reduces the memory used by almost a factor of two,
at the cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
"""
observations: np.ndarray
next_observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
dones: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
# Adjust buffer size
self.buffer_size = max(buffer_size // n_envs, 1)
# Check that the replay buffer can fit into the memory
if psutil is not None:
mem_available = psutil.virtual_memory().available
self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
if not optimize_memory_usage:
# When optimizing memory, `observations` also contains the next observation
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
if psutil is not None:
total_memory_usage: float = (
self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
)
if not optimize_memory_usage:
total_memory_usage += self.next_observations.nbytes
if total_memory_usage > mem_available:
# Convert to GB
total_memory_usage /= 1e9
mem_available /= 1e9
warnings.warn(
"This system does not have apparently enough memory to store the complete "
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
)
def add(
self,
obs: np.ndarray,
next_obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs, *self.obs_shape))
next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs)
if self.optimize_memory_usage:
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs)
else:
self.next_observations[self.pos] = np.array(next_obs)
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0
def sample(self, batch_size: int):
"""
Sample elements from the replay buffer.
Custom sampling when using memory efficient variant,
as we should not sample the element with index `self.pos`
See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
:param batch_size: Number of elements to sample
:return:
"""
if not self.optimize_memory_usage:
return super().sample(batch_size=batch_size)
# Do not sample the element with index `self.pos` as that transition is invalid
# (we use only one array to store `obs` and `next_obs`)
if self.full:
batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
else:
batch_inds = np.random.randint(0, self.pos, size=batch_size)
return self._get_samples(batch_inds)
def _get_samples(self, batch_inds: np.ndarray):
# Sample randomly the env idx
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
if self.optimize_memory_usage:
next_obs = self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :]
else:
next_obs = self.next_observations[batch_inds, env_indices, :]
data = (
self.observations[batch_inds, env_indices, :],
self.actions[batch_inds, env_indices, :],
next_obs,
self.dones[batch_inds, env_indices].reshape(-1, 1),
self.rewards[batch_inds, env_indices].reshape(-1, 1),
)
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
@staticmethod
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
"""
Cast `np.float64` action datatype to `np.float32`,
keep the other dtypes unchanged.
See GH#1572 for more information.
:param dtype: The original action space dtype
:return: ``np.float32`` if the dtype was float64,
the original dtype otherwise.
"""
if dtype == np.float64:
return np.float32
return dtype
dtype_dict = {
np.bool_ : th.bool,
np.uint8 : th.uint8,
np.int8 : th.int8,
np.int16 : th.int16,
np.int32 : th.int32,
np.int64 : th.int64,
np.float16 : th.float16,
np.float32 : th.float32,
np.float64 : th.float64,
}
@numba.njit
def nstep_return_selfplay(rewards, to_play, gamma):
returns = np.zeros_like(rewards)
R0 = rewards[-1]
R1 = -rewards[-1]
returns[-1] = R0
pl = to_play[-1]
for step in np.arange(len(rewards) - 2, -1, -1):
if to_play[step] == pl:
R0 = gamma * R0 + rewards[step]
returns[step] = R0
else:
R1 = gamma * R1 - rewards[step]
returns[step] = R1
return returns
@numba.njit
def nstep_return(rewards, gamma):
returns = np.zeros_like(rewards)
R = 0.0
for step in np.arange(len(rewards) - 1, -1, -1):
R = rewards[step] + gamma * R
returns[step] = R
return returns
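# Both helpers accumulate full-episode discounted returns backwards, R_t = r_t + gamma * R_{t+1};
# the selfplay variant keeps one accumulator per player (indexed by to_play) and negates
# rewards for the opponent's side. Worked example (illustrative): rewards [0, 0, 1] with
# gamma = 0.9 give returns [0.81, 0.9, 1.0] from nstep_return.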
class DMCBuffer:
observations: np.ndarray
actions: np.ndarray
returns: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
n_envs: int = 1,
device: Union[th.device, str] = "auto",
):
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment]
self.action_dim = get_action_dim(action_space)
self.n_envs = n_envs
self.pos = 0
self.full = False
self.start = np.zeros((self.n_envs,), dtype=np.int32)
# Adjust buffer size
self.buffer_size = max(buffer_size // n_envs, 1)
self.device = get_device(device)
self.observations = th.zeros(
(self.buffer_size, self.n_envs, *self.obs_shape), dtype=dtype_dict[observation_space.dtype.type], device=self.device)
self.actions = th.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=dtype_dict[action_space.dtype.type], device=self.device)
self.returns = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
def size(self) -> int:
"""
:return: The current size of the buffer
"""
if self.full:
return self.buffer_size
return self.pos
def add(
self,
obs: th.Tensor,
action: th.Tensor,
reward: np.ndarray,
) -> None:
batch_size = reward.shape[0]
self.observations[self.pos] = obs
self.actions[self.pos] = action.reshape((batch_size, self.action_dim))
self.rewards[self.pos] = np.array(reward)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0
def mark_episode(self, env_ind, gamma):
start = self.start[env_ind]
pos = self.pos
if pos <= start:
end = pos + self.buffer_size
batch_inds = np.arange(start, end) % self.buffer_size
else:
batch_inds = np.arange(start, pos)
self.start[env_ind] = pos
returns = nstep_return(self.rewards[batch_inds, env_ind], gamma)
self.returns[batch_inds, env_ind] = th.from_numpy(returns).to(self.device)
def get_data_indices(self):
if not self.full:
indices = np.arange(self.start.min())
else:
# if np.all(pos >= self.start):
# indices = np.arange(pos, self.start.min() + self.buffer_size) % self.buffer_size
# elif np.all(pos < self.start):
# indices = np.arange(pos, self.start.min())
# else:
start = self.pos
end = np.where(self.pos >= self.start, self.start + self.buffer_size, self.start).min()
indices = np.arange(start, end) % self.buffer_size
return indices
def _get_samples(self, batch_inds: np.ndarray):
data = (
self.observations[batch_inds, :, :].reshape(-1, *self.obs_shape),
self.actions[batch_inds, :, :].reshape(-1, self.action_dim),
self.returns[batch_inds, :].reshape(-1),
)
return data
class DMCDictBuffer:
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
selfplay: bool = False,
):
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment]
self.selfplay = selfplay
self.action_dim = get_action_dim(action_space)
self.n_envs = n_envs
self.pos = 0
self.full = False
self.device = get_device(device)
self.start = np.zeros((self.n_envs,), dtype=np.int32)
self.buffer_size = max(buffer_size // n_envs, 1)
self.observations = {
key: th.zeros(
(self.buffer_size, self.n_envs, *_obs_shape),
dtype=dtype_dict[observation_space[key].dtype.type], device=self.device)
for key, _obs_shape in self.obs_shape.items()
}
self.actions = th.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=dtype_dict[action_space.dtype.type], device=self.device)
self.returns = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
if self.selfplay:
self.to_play = np.zeros((self.buffer_size, self.n_envs), dtype=np.int32)
self._observations = {
key: th.zeros(
(self.buffer_size * self.n_envs, *_obs_shape),
dtype=dtype_dict[observation_space[key].dtype.type], device=self.device)
for key, _obs_shape in self.obs_shape.items()
}
self._actions = th.zeros(
(self.buffer_size * self.n_envs, self.action_dim), dtype=dtype_dict[action_space.dtype.type], device=self.device)
self._returns = th.zeros((self.buffer_size * self.n_envs,), dtype=th.float32, device=self.device)
obs_nbytes = 0
for _, obs in self.observations.items():
obs_nbytes += obs.nbytes
total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.returns.nbytes
total_memory_usage = total_memory_usage / 1e9 * 2
print(f"Total gpu memory usage: {total_memory_usage:.2f}GB")
def size(self) -> int:
"""
:return: The current size of the buffer
"""
if self.full:
return self.buffer_size
return self.pos
def add(
self,
obs: th.Tensor,
action: th.Tensor,
reward: np.ndarray,
to_play: Optional[np.ndarray] = None,
) -> None:
batch_size = reward.shape[0]
for key in self.observations.keys():
self.observations[key][self.pos] = obs[key]
self.actions[self.pos] = action.reshape((batch_size, self.action_dim))
self.rewards[self.pos] = np.array(reward)
if self.selfplay:
self.to_play[self.pos] = to_play
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0
def mark_episode(self, env_ind, gamma):
start = self.start[env_ind]
pos = self.pos
if pos <= start:
end = pos + self.buffer_size
batch_inds = np.arange(start, end) % self.buffer_size
else:
batch_inds = np.arange(start, pos)
self.start[env_ind] = pos
if self.selfplay:
returns = nstep_return_selfplay(self.rewards[batch_inds, env_ind], self.to_play[batch_inds, env_ind], gamma)
else:
returns = nstep_return(self.rewards[batch_inds, env_ind], gamma)
self.returns[batch_inds, env_ind] = th.from_numpy(returns).to(self.device)
def get_data_indices(self):
if not self.full:
indices = np.arange(self.start.min())
# print(0, self.start.min(), self.pos, self.start, self.full)
else:
start = self.pos
end = np.where(self.pos >= self.start, self.start + self.buffer_size, self.start).min()
indices = np.arange(start, end) % self.buffer_size
# print(start, end, self.pos, self.start, self.full)
return indices
def _get_samples(self, batch_inds: np.ndarray):
l = len(batch_inds) * self.n_envs
for key, obs in self.observations.items():
_obs = self._observations[key]
_obs[:l, :] = obs[batch_inds, :, :].flatten(0, 1)
self._actions[:l, :] = self.actions[batch_inds, :, :].reshape(-1, self.action_dim)
self._returns[:l] = self.returns[batch_inds, :].reshape(-1)
data = (
{key: _obs[:l] for key, _obs in self._observations.items()},
self._actions[:l],
self._returns[:l],
)
return data
class DictReplayBuffer:
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]]
observations: Dict[str, np.ndarray]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
):
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space)
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
self.device = get_device(device)
self.n_envs = n_envs
assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
self.buffer_size = max(buffer_size // n_envs, 1)
# Check that the replay buffer can fit into the memory
mem_available = psutil.virtual_memory().available
self.observations = {
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
for key, _obs_shape in self.obs_shape.items()
}
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
obs_nbytes = 0
for _, obs in self.observations.items():
obs_nbytes += obs.nbytes
total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
if total_memory_usage > mem_available:
# Convert to GB
total_memory_usage /= 1e9
mem_available /= 1e9
warnings.warn(
"This system does not have apparently enough memory to store the complete "
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
)
def size(self) -> int:
"""
:return: The current size of the buffer
"""
if self.full:
return self.buffer_size
return self.pos
def extend(self, *args, **kwargs) -> None:
"""
Add a new batch of transitions to the buffer
"""
# Do a for loop along the batch axis
for data in zip(*args):
self.add(*data)
def reset(self) -> None:
"""
Reset the buffer.
"""
self.pos = 0
self.full = False
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
) -> None:
# Copy to avoid modification by reference
for key in self.observations.keys():
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = np.array(obs[key])
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0
def sample( # type: ignore[override]
self,
batch_size: int,
) -> DictReplayBufferSamples:
"""
Sample elements from the replay buffer.
:param batch_size: Number of elements to sample
:return:
"""
if self.full:
batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
else:
batch_inds = np.random.randint(0, self.pos, size=batch_size)
return self._get_samples(batch_inds)
def _get_samples( # type: ignore[override]
self,
batch_inds: np.ndarray,
) -> DictReplayBufferSamples:
# Sample randomly the env idx
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
obs_ = {key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}
assert isinstance(obs_, dict)
# Convert to torch tensor
observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
return DictReplayBufferSamples(
observations=observations,
actions=self.to_torch(self.actions[batch_inds, env_indices]),
dones=self.to_torch(self.dones[batch_inds, env_indices]).reshape(-1, 1),
rewards=self.to_torch(self.rewards[batch_inds, env_indices].reshape(-1, 1)),
)
@staticmethod
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
"""
Cast `np.float64` action datatype to `np.float32`,
keep the other dtypes unchanged.
See GH#1572 for more information.
:param dtype: The original action space dtype
:return: ``np.float32`` if the dtype was float64,
the original dtype otherwise.
"""
if dtype == np.float64:
return np.float32
return dtype
def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
"""
Convert a numpy array to a PyTorch tensor.
Note: it copies the data by default
:param array:
:param copy: Whether or not to copy the data (may be useful to avoid changing things
by reference). This argument is inoperative if the device is not the CPU.
:return:
"""
if copy:
return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device)
import re
import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
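# Hedged usage sketch (env construction and actions are illustrative): wrap a vectorized
# env so that infos carries running episode statistics; on the step where an episode ends,
# infos["r"] / infos["l"] hold that episode's final return and length.
#   envs = RecordEpisodeStatistics(vector_env)
#   obs, infos = envs.reset()
#   obs, rewards, dones, infos = envs.step(actions)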
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
def split_param_groups(model, regex):
embed_params = []
other_params = []
for name, param in model.named_parameters():
if re.search(regex, name):
embed_params.append(param)
else:
other_params.append(param)
return [
{'params': embed_params}, {'params': other_params}
]
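# Hedged usage sketch (the regex and learning rates are illustrative, and torch is assumed
# to be imported): split the parameters so embedding-like weights can use their own
# learning rate while everything else keeps the optimizer default.
#   groups = split_param_groups(model, r"embed")
#   groups[0]["lr"] = 1e-4
#   optimizer = torch.optim.Adam(groups, lr=2.5e-4)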
class Elo:
def __init__(self, k=10, r0=1500, r1=1500):
self.r0 = r0
self.r1 = r1
self.k = k
def update(self, winner):
diff = self.k * (1 - self.expect_result(self.r0, self.r1))
if winner == 1:
diff = -diff
self.r0 += diff
self.r1 -= diff
def expect_result(self, p0, p1):
# Expected score of player 0 under the standard Elo formula.
exp = (p1 - p0) / 400.0
return 1 / ((10.0 ** exp) + 1)
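# Hedged usage sketch: with equal starting ratings the first update moves both by k/2.
#   elo = Elo(k=10)
#   elo.update(0)   # player 0 won -> elo.r0 == 1505.0, elo.r1 == 1495.0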
import itertools
from pathlib import Path
from ygoenv.ygopro import init_module
def load_deck(fn):
with open(fn) as f:
lines = f.readlines()
noside = itertools.takewhile(lambda x: "side" not in x, lines)
deck = [int(line) for line in noside if line[:-1].isdigit()]
return deck
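# Hedged usage sketch (the path is illustrative): read the main/extra deck card codes
# from a .ydk file, stopping before the "!side" section.
#   main_and_extra = load_deck("deck/OldSchool.ydk")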
def get_root_directory():
cur = Path(__file__).resolve()
return str(cur.parent.parent)
def extract_deck_name(path):
return Path(path).stem
_languages = {
"english": "en",
"chinese": "zh",
}
def init_ygopro(lang, deck, code_list_file, preload_tokens=True):
short = _languages[lang]
db_path = Path(get_root_directory(), 'assets', 'locale', short, 'cards.cdb')
deck_fp = Path(deck)
if deck_fp.is_dir():
decks = {f.stem: str(f) for f in deck_fp.glob("*.ydk")}
deck_dir = deck_fp
deck_name = 'random'
else:
deck_name = deck_fp.stem
decks = {deck_name: deck}
deck_dir = deck_fp.parent
if preload_tokens:
token_deck = deck_dir / "_tokens.ydk"
if not token_deck.exists():
raise FileNotFoundError(f"Token deck not found: {token_deck}")
decks["_tokens"] = str(token_deck)
init_module(str(db_path), code_list_file, decks)
return deck_name
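# Hedged usage sketch (paths are illustrative): point the env module at the locale's
# cards.cdb, a code list file, and either a single .ydk deck or a directory of decks.
#   deck_name = init_ygopro("english", "deck/OldSchool.ydk", "code_list.txt")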
include ygoenv/*/*.so
import envpool2
print(envpool2.list_all_envs())
from setuptools import setup, find_packages
__version__ = "0.0.1"
INSTALL_REQUIRES = [
"setuptools",
"wheel",
"numpy",
"dm-env",
"gym>=0.26",
"gymnasium>=0.26,!=0.27.0",
"optree>=0.6.0",
"packaging",
]
setup(
name="ygoenv",
version=__version__,
packages=find_packages(include='ygoenv*'),
long_description="",
install_requires=INSTALL_REQUIRES,
python_requires=">=3.7",
include_package_data=True,
)
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvPool package for efficient RL environment simulation."""
import ygoenv.entry # noqa: F401
from ygoenv.registration import (
list_all_envs,
make,
make_dm,
make_gym,
make_gymnasium,
make_spec,
register,
)
__version__ = "0.8.4"
__all__ = [
"register",
"make",
"make_dm",
"make_gym",
"make_gymnasium",
"make_spec",
"list_all_envs",
]
#ifndef THREAD_POOL_H
#define THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
}
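// Hedged usage sketch (values are illustrative): enqueue returns a std::future for the
// task's result, so callers can block on completion.
//   ThreadPool pool(4);
//   auto result = pool.enqueue([](int x) { return x * x; }, 7);
//   int squared = result.get();  // 49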
#endif
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ACTION_BUFFER_QUEUE_H_
#define YGOENV_CORE_ACTION_BUFFER_QUEUE_H_
#ifndef MOODYCAMEL_DELETE_FUNCTION
#define MOODYCAMEL_DELETE_FUNCTION = delete
#endif
#include <atomic>
#include <cassert>
#include <utility>
#include <vector>
#include "ygoenv/core/array.h"
#include "concurrentqueue/moodycamel/lightweightsemaphore.h"
/**
* Lock-free action buffer queue.
*/
class ActionBufferQueue {
public:
struct ActionSlice {
int env_id;
int order;
bool force_reset;
};
protected:
std::atomic<uint64_t> alloc_ptr_, done_ptr_;
std::size_t queue_size_;
std::vector<ActionSlice> queue_;
moodycamel::LightweightSemaphore sem_, sem_enqueue_, sem_dequeue_;
public:
explicit ActionBufferQueue(std::size_t num_envs)
: alloc_ptr_(0),
done_ptr_(0),
queue_size_(num_envs * 2),
queue_(queue_size_),
sem_(0),
sem_enqueue_(1),
sem_dequeue_(1) {}
void EnqueueBulk(const std::vector<ActionSlice>& action) {
// ensure only one enqueue_bulk happens at any time
while (!sem_enqueue_.wait()) {
}
uint64_t pos = alloc_ptr_.fetch_add(action.size());
for (std::size_t i = 0; i < action.size(); ++i) {
queue_[(pos + i) % queue_size_] = action[i];
}
sem_.signal(action.size());
sem_enqueue_.signal(1);
}
ActionSlice Dequeue() {
while (!sem_.wait()) {
}
while (!sem_dequeue_.wait()) {
}
auto ptr = done_ptr_.fetch_add(1);
auto ret = queue_[ptr % queue_size_];
sem_dequeue_.signal(1);
return ret;
}
std::size_t SizeApprox() {
return static_cast<std::size_t>(alloc_ptr_ - done_ptr_);
}
};
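// Hedged usage sketch (field values are illustrative): producers enqueue batches of
// slices and worker threads pop them one at a time.
//   ActionBufferQueue queue(/*num_envs=*/8);
//   queue.EnqueueBulk({ActionBufferQueue::ActionSlice{
//       .env_id = 0, .order = -1, .force_reset = false}});
//   ActionBufferQueue::ActionSlice next = queue.Dequeue();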
#endif // YGOENV_CORE_ACTION_BUFFER_QUEUE_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ARRAY_H_
#define YGOENV_CORE_ARRAY_H_
#include <glog/logging.h>
#include <cstddef>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include "ygoenv/core/spec.h"
class Array {
public:
std::size_t size;
std::size_t ndim;
std::size_t element_size;
protected:
std::vector<std::size_t> shape_;
std::shared_ptr<char> ptr_;
template <class Shape, class Deleter>
Array(char* ptr, Shape&& shape, std::size_t element_size, // NOLINT
Deleter&& deleter)
: size(Prod(shape.data(), shape.size())),
ndim(shape.size()),
element_size(element_size),
shape_(std::forward<Shape>(shape)),
ptr_(ptr, std::forward<Deleter>(deleter)) {}
template <class Shape>
Array(std::shared_ptr<char> ptr, Shape&& shape, std::size_t element_size)
: size(Prod(shape.data(), shape.size())),
ndim(shape.size()),
element_size(element_size),
shape_(std::forward<Shape>(shape)),
ptr_(std::move(ptr)) {}
public:
Array() = default;
/**
* Construct an `Array` with the shape defined by `spec`, using `data` as the
* pointer to its raw memory. The deleter is a no-op, so the Array does not own
* the memory.
*/
template <class Deleter>
Array(const ShapeSpec& spec, char* data, Deleter&& deleter) // NOLINT
: Array(data, spec.Shape(), spec.element_size,
std::forward<Deleter>(deleter)) {}
Array(const ShapeSpec& spec, char* data)
: Array(data, spec.Shape(), spec.element_size, [](char* /*unused*/) {}) {}
/**
* Construct an `Array` of the shape defined by `spec`. This constructor
* allocates and owns the memory.
*/
explicit Array(const ShapeSpec& spec)
: Array(spec, nullptr, [](char* /*unused*/) {}) {
ptr_.reset(new char[size * element_size](),
[](const char* p) { delete[] p; });
}
/**
* Take multidimensional index into the Array.
*/
template <typename... Index>
inline Array operator()(Index... index) const {
constexpr std::size_t num_index = sizeof...(Index);
DCHECK_GE(ndim, num_index);
std::size_t offset = 0;
std::size_t i = 0;
for (((offset = offset * shape_[i++] + index), ...); i < ndim; ++i) {
offset *= shape_[i];
}
return Array(
ptr_.get() + offset * element_size,
std::vector<std::size_t>(shape_.begin() + num_index, shape_.end()),
element_size, [](char* /*unused*/) {});
}
/**
* Index operator of array, takes the index along the first axis.
*/
inline Array operator[](int index) const { return this->operator()(index); }
/**
* Take a slice at the first axis of the Array.
*/
[[nodiscard]] Array Slice(std::size_t start, std::size_t end) const {
DCHECK_GT(ndim, (std::size_t)0);
CHECK_GE(shape_[0], end);
CHECK_GE(end, start);
std::vector<std::size_t> new_shape(shape_);
new_shape[0] = end - start;
std::size_t offset = 0;
if (shape_[0] > 0) {
offset = start * size / shape_[0];
}
return {ptr_.get() + offset * element_size, std::move(new_shape),
element_size, [](char* p) {}};
}
/**
* Copy the content of another Array to this Array.
*/
void Assign(const Array& value) const {
DCHECK_EQ(element_size, value.element_size)
<< " element size doesn't match";
DCHECK_EQ(size, value.size) << " size doesn't match";
std::memcpy(ptr_.get(), value.ptr_.get(), size * element_size);
}
/**
* Assign to this Array a scalar value. This Array needs to have a scalar
* shape.
*/
template <typename T,
std::enable_if_t<!std::is_same_v<T, Array>, bool> = true>
void operator=(const T& value) const {
DCHECK_EQ(element_size, sizeof(T)) << " element size doesn't match";
DCHECK_EQ(size, (std::size_t)1) << " assigning scalar to non-scalar array";
*reinterpret_cast<T*>(ptr_.get()) = value;
}
/**
* Fills this array with a scalar value of type T.
*/
template <typename T>
void Fill(const T& value) const {
DCHECK_EQ(element_size, sizeof(T)) << " element size doesn't match";
auto* data = reinterpret_cast<T*>(ptr_.get());
std::fill(data, data + size, value);
}
/**
* Copy `sz` elements of type `T` starting at `buff` into the memory of this
* Array.
*/
template <typename T>
void Assign(const T* buff, std::size_t sz) const {
DCHECK_EQ(sz, size) << " assignment size mismatch";
DCHECK_EQ(sizeof(T), element_size) << " element size mismatch";
std::memcpy(ptr_.get(), buff, sz * sizeof(T));
}
/**
* Size of axis `dim`.
*/
[[nodiscard]] inline std::size_t Shape(std::size_t dim) const {
return shape_[dim];
}
/**
* Shape
*/
[[nodiscard]] inline const std::vector<std::size_t>& Shape() const {
return shape_;
}
/**
* Pointer to the raw memory.
*/
[[nodiscard]] inline void* Data() const { return ptr_.get(); }
/**
* Truncate the Array. Return a new Array that shares the same memory
* location but with a truncated shape.
*/
[[nodiscard]] Array Truncate(std::size_t end) const {
auto new_shape = std::vector<std::size_t>(shape_);
new_shape[0] = end;
Array ret(ptr_, std::move(new_shape), element_size);
return ret;
}
void Zero() const { std::memset(ptr_.get(), 0, size * element_size); }
[[nodiscard]] std::shared_ptr<char> SharedPtr() const { return ptr_; }
};
template <typename Dtype>
class TArray : public Array {
template <class Shape, class Deleter>
TArray(char* ptr, Shape&& shape, std::size_t element_size, // NOLINT
Deleter&& deleter)
: Array(ptr, shape, element_size, deleter) {}
template <class Shape>
TArray(const std::shared_ptr<char>& ptr, Shape&& shape,
std::size_t element_size)
: Array(ptr, shape, element_size) {}
public:
TArray() = default;
explicit TArray(const Spec<Dtype>& spec) : Array(spec) {}
explicit TArray(const Spec<Dtype>& spec, const char* data)
: Array(spec, data) {}
template <typename A, std::enable_if_t<std::is_same_v<std::decay_t<A>, Array>,
bool> = true>
explicit TArray(A&& array) : Array(std::forward<A>(array)) { // NOLINT
DCHECK_EQ(array.element_size, sizeof(Dtype));
}
/**
* Take multidimensional index into the Array.
*/
template <typename... Index>
inline TArray operator()(Index... index) const {
return TArray(Array::operator()(index...));
}
/**
* Index operator of array, takes the index along the first axis.
*/
inline TArray operator[](int index) const { return this->operator()(index); }
/**
* Take a slice at the first axis of the Array.
*/
[[nodiscard]] TArray Slice(std::size_t start, std::size_t end) const {
return TArray(Array::Slice(start, end));
}
/**
* Copy the content of another Array to this Array.
*/
void Assign(const TArray& value) const { Array::Assign(value); }
void Assign(const Array& value) const { Array::Assign(value); }
/**
* Assign a scalar value.
*/
template <typename T,
std::enable_if_t<!std::is_same_v<T, TArray>, bool> = true>
void operator=(const T& value) const {
*reinterpret_cast<Dtype*>(ptr_.get()) = static_cast<Dtype>(value);
}
/**
* Fills this array with a scalar value of type T.
*/
template <typename T>
void Fill(const T& value) const {
auto data = reinterpret_cast<Dtype*>(ptr_.get());
std::fill(data, data + size, static_cast<Dtype>(value));
}
/**
* Copy `sz` elements of type `Dtype` starting at `buff` into the memory of this
* Array.
*/
void Assign(const Dtype* buff, std::size_t sz) const {
std::memcpy(ptr_.get(), buff, sz * sizeof(Dtype));
}
operator Dtype&() const { // NOLINT
return *reinterpret_cast<Dtype*>(ptr_.get());
}
/**
* Cast the Array to a scalar value of type `T`. This Array needs to have a
* scalar shape.
*/
template <typename T,
std::enable_if_t<!std::is_same_v<T, Dtype>, bool> = true>
operator T() const { // NOLINT
DCHECK_EQ(size, (std::size_t)1)
<< " Array with a non-scalar shape can't be used as a scalar";
return static_cast<T>(*reinterpret_cast<Dtype*>(ptr_.get()));
}
/**
* Truncate the Array. Return a new Array that shares the same memory
* location but with a truncated shape.
*/
[[nodiscard]] TArray Truncate(std::size_t end) const {
auto new_shape = std::vector<std::size_t>(shape_);
new_shape[0] = end;
TArray ret(ptr_, std::move(new_shape), element_size);
return ret;
}
};
#endif // YGOENV_CORE_ARRAY_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ASYNC_ENVPOOL_H_
#define YGOENV_CORE_ASYNC_ENVPOOL_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <thread>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "ygoenv/core/action_buffer_queue.h"
#include "ygoenv/core/array.h"
#include "ygoenv/core/envpool.h"
#include "ygoenv/core/state_buffer_queue.h"
/**
* Async EnvPool
*
* batch-action -> action buffer queue -> threadpool -> state buffer queue
*
* ThreadPool is tailored with EnvPool, so here we don't use the existing
* third_party ThreadPool (which is really slow).
*/
template <typename Env>
class AsyncEnvPool : public EnvPool<typename Env::Spec> {
protected:
std::size_t num_envs_;
std::size_t batch_;
std::size_t max_num_players_;
std::size_t num_threads_;
bool is_sync_;
std::atomic<int> stop_;
std::atomic<std::size_t> stepping_env_num_;
std::vector<std::thread> workers_;
std::unique_ptr<ActionBufferQueue> action_buffer_queue_;
std::unique_ptr<StateBufferQueue> state_buffer_queue_;
std::vector<std::unique_ptr<Env>> envs_;
std::vector<std::atomic<int>> stepping_env_;
std::chrono::duration<double> dur_send_, dur_recv_, dur_send_all_;
template <typename V>
void SendImpl(V&& action) {
int* env_id = static_cast<int*>(action[0].Data());
int shared_offset = action[0].Shape(0);
std::vector<ActionSlice> actions;
std::shared_ptr<std::vector<Array>> action_batch =
std::make_shared<std::vector<Array>>(std::forward<V>(action));
for (int i = 0; i < shared_offset; ++i) {
int eid = env_id[i];
envs_[eid]->SetAction(action_batch, i);
actions.emplace_back(ActionSlice{
.env_id = eid,
.order = is_sync_ ? i : -1,
.force_reset = false,
});
}
if (is_sync_) {
stepping_env_num_ += shared_offset;
}
// add to abq
auto start = std::chrono::system_clock::now();
action_buffer_queue_->EnqueueBulk(actions);
dur_send_ += std::chrono::system_clock::now() - start;
}
public:
using Spec = typename Env::Spec;
using Action = typename Env::Action;
using State = typename Env::State;
using ActionSlice = typename ActionBufferQueue::ActionSlice;
explicit AsyncEnvPool(const Spec& spec)
: EnvPool<Spec>(spec),
num_envs_(spec.config["num_envs"_]),
batch_(spec.config["batch_size"_] <= 0 ? num_envs_
: spec.config["batch_size"_]),
max_num_players_(spec.config["max_num_players"_]),
num_threads_(spec.config["num_threads"_]),
is_sync_(batch_ == num_envs_ && max_num_players_ == 1),
stop_(0),
stepping_env_num_(0),
action_buffer_queue_(new ActionBufferQueue(num_envs_)),
state_buffer_queue_(new StateBufferQueue(
batch_, num_envs_, max_num_players_,
spec.state_spec.template AllValues<ShapeSpec>())),
envs_(num_envs_) {
std::size_t processor_count = std::thread::hardware_concurrency();
ThreadPool init_pool(std::min(processor_count, num_envs_));
std::vector<std::future<void>> result;
for (std::size_t i = 0; i < num_envs_; ++i) {
result.emplace_back(init_pool.enqueue(
[i, spec, this] { envs_[i].reset(new Env(spec, i)); }));
}
for (auto& f : result) {
f.get();
}
if (num_threads_ == 0) {
num_threads_ = std::min(batch_, processor_count);
}
for (std::size_t i = 0; i < num_threads_; ++i) {
workers_.emplace_back([this] {
for (;;) {
ActionSlice raw_action = action_buffer_queue_->Dequeue();
if (stop_ == 1) {
break;
}
int env_id = raw_action.env_id;
int order = raw_action.order;
bool reset = raw_action.force_reset || envs_[env_id]->IsDone();
envs_[env_id]->EnvStep(state_buffer_queue_.get(), order, reset);
}
});
}
if (spec.config["thread_affinity_offset"_] >= 0) {
std::size_t thread_affinity_offset =
spec.config["thread_affinity_offset"_];
for (std::size_t tid = 0; tid < num_threads_; ++tid) {
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
std::size_t cid = (thread_affinity_offset + tid) % processor_count;
CPU_SET(cid, &cpuset);
pthread_setaffinity_np(workers_[tid].native_handle(), sizeof(cpu_set_t),
&cpuset);
}
}
}
~AsyncEnvPool() override {
stop_ = 1;
// LOG(INFO) << "envpool send: " << dur_send_.count();
// LOG(INFO) << "envpool recv: " << dur_recv_.count();
// send n actions to clear threadpool
std::vector<ActionSlice> empty_actions(workers_.size());
action_buffer_queue_->EnqueueBulk(empty_actions);
for (auto& worker : workers_) {
worker.join();
}
}
void Send(const Action& action) {
SendImpl(action.template AllValues<Array>());
}
void Send(const std::vector<Array>& action) override { SendImpl(action); }
void Send(std::vector<Array>&& action) override { SendImpl(action); }
std::vector<Array> Recv() override {
int additional_wait = 0;
if (is_sync_ && stepping_env_num_ < batch_) {
additional_wait = batch_ - stepping_env_num_;
}
auto start = std::chrono::system_clock::now();
auto ret = state_buffer_queue_->Wait(additional_wait);
dur_recv_ += std::chrono::system_clock::now() - start;
if (is_sync_) {
stepping_env_num_ -= ret[0].Shape(0);
}
return ret;
}
void Reset(const Array& env_ids) override {
TArray<int> tenv_ids(env_ids);
int shared_offset = tenv_ids.Shape(0);
std::vector<ActionSlice> actions(shared_offset);
for (int i = 0; i < shared_offset; ++i) {
actions[i].force_reset = true;
actions[i].env_id = tenv_ids[i];
actions[i].order = is_sync_ ? i : -1;
}
if (is_sync_) {
stepping_env_num_ += shared_offset;
}
action_buffer_queue_->EnqueueBulk(actions);
}
};
#endif // YGOENV_CORE_ASYNC_ENVPOOL_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_CIRCULAR_BUFFER_H_
#define YGOENV_CORE_CIRCULAR_BUFFER_H_
#ifndef MOODYCAMEL_DELETE_FUNCTION
#define MOODYCAMEL_DELETE_FUNCTION = delete
#endif
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>
#include "concurrentqueue/moodycamel/lightweightsemaphore.h"
template <typename V>
class CircularBuffer {
protected:
std::size_t size_;
moodycamel::LightweightSemaphore sem_get_;
moodycamel::LightweightSemaphore sem_put_;
std::vector<V> buffer_;
std::atomic<uint64_t> head_;
std::atomic<uint64_t> tail_;
public:
explicit CircularBuffer(std::size_t size)
: size_(size), sem_put_(size), buffer_(size), head_(0), tail_(0) {}
template <typename T>
void Put(T&& v) {
while (!sem_put_.wait()) {
}
uint64_t tail = tail_.fetch_add(1);
auto offset = tail % size_;
buffer_[offset] = std::forward<T>(v);
sem_get_.signal();
}
V Get() {
while (!sem_get_.wait()) {
}
uint64_t head = head_.fetch_add(1);
auto offset = head % size_;
V v = std::move(buffer_[offset]);
sem_put_.signal();
return v;
}
};
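// Hedged usage sketch: a bounded blocking buffer where Put blocks when full and Get
// blocks when empty.
//   CircularBuffer<int> buffer(4);
//   buffer.Put(42);
//   int value = buffer.Get();  // 42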
#endif // YGOENV_CORE_CIRCULAR_BUFFER_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_DICT_H_
#define YGOENV_CORE_DICT_H_
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "ygoenv/core/array.h"
#include "ygoenv/core/spec.h"
#include "ygoenv/core/tuple_utils.h"
#include "ygoenv/core/type_utils.h"
template <typename K, typename D>
class Value {
public:
using Key = K;
using Type = D;
explicit Value(Type&& v) : v(v) {}
Type v;
};
template <char... C>
class Key {
public:
static constexpr const inline char kStr[sizeof...(C) + 1]{C..., // NOLINT
'\0'};
static constexpr const inline std::string_view kStrView{kStr, sizeof...(C)};
template <typename Type>
static constexpr inline auto Bind(Type&& v) {
return Value<Key, Type>(std::forward<Type>(v));
}
static inline std::string Str() { return {kStrView.data(), kStrView.size()}; }
};
template <class CharT, CharT... CS>
inline constexpr auto operator""_() { // NOLINT
return Key<CS...>{};
}
template <
typename Key, typename Keys, typename TupleOrVector,
std::enable_if_t<is_tuple_v<std::decay_t<TupleOrVector>>, bool> = true>
inline decltype(auto) Take(const Key& key, TupleOrVector&& values) {
constexpr std::size_t index = Index<Key, Keys>::kValue;
return std::get<index>(std::forward<TupleOrVector>(values));
}
template <
typename Key, typename Keys, typename TupleOrVector,
std::enable_if_t<is_vector_v<std::decay_t<TupleOrVector>>, bool> = true>
inline decltype(auto) Take(const Key& key, TupleOrVector&& values) {
constexpr std::size_t index = Index<Key, Keys>::kValue;
return std::forward<TupleOrVector>(values).at(index);
}
template <typename StringKeys, typename Vector,
std::enable_if_t<is_vector_v<std::decay_t<Vector>>, bool> = true>
class NamedVector {
protected:
Vector* values_;
public:
using Keys = StringKeys;
constexpr static std::size_t kSize = std::tuple_size<Keys>::value;
explicit NamedVector(Vector* values) : values_(values) {}
NamedVector(const Keys& keys, Vector* values) : values_(values) {}
template <typename Key,
std::enable_if_t<any_match<Key, Keys>::value, bool> = true>
inline decltype(auto) operator[](const Key& key) const {
return Take<Key, Keys, Vector&>(key, *values_);
}
/**
* Return a static constexpr list of all the keys in a tuple.
*/
static constexpr decltype(auto) StaticKeys() { return Keys(); }
/**
* Return a list of all the keys as strings.
*/
static std::vector<std::string> AllKeys() {
std::vector<std::string> rets;
std::apply([&](auto&&... key) { (rets.push_back(key.Str()), ...); },
Keys());
return rets;
}
operator Vector&() const { // NOLINT
return *values_;
}
};
template <typename StringKeys, typename TupleOrVector,
typename = std::enable_if_t<is_tuple_v<StringKeys>>>
class Dict : public std::decay_t<TupleOrVector> {
public:
using Values = std::decay_t<TupleOrVector>;
using Keys = StringKeys;
constexpr static std::size_t kSize = std::tuple_size<Keys>::value;
/**
* Check that the size of values / keys tuple should match
*/
template <typename V = Values, std::enable_if_t<is_tuple_v<V>, bool> = true>
void Check() {
static_assert(std::tuple_size_v<Keys> == std::tuple_size_v<Values>,
"Number of keys and values doesn't match");
}
template <typename V = Values, std::enable_if_t<is_vector_v<V>, bool> = true>
void Check() {
DCHECK_EQ(std::tuple_size<Keys>(), static_cast<V*>(this)->size())
<< "Size must match";
}
Dict() = default;
/**
* Constructor, makes a dict from keys and values
*/
Dict(const Keys& keys, TupleOrVector&& values) : Values(std::move(values)) {
Check();
}
Dict(const Keys& keys, const TupleOrVector& values) : Values(values) {
Check();
}
/**
* Constructor, needs to be called with template types
*/
explicit Dict(TupleOrVector&& values) : Values(std::move(values)) { Check(); }
explicit Dict(const TupleOrVector& values) : Values(values) { Check(); }
template <typename V, typename V2 = Values,
std::enable_if_t<is_vector_v<std::decay_t<V>>, bool> = true,
std::enable_if_t<is_tuple_v<V2>, bool> = true>
explicit Dict(V&& values) // NOLINT
: Dict(TupleFromVector<Values>(std::forward<V>(values))) {}
/**
* Gives the values an operator[] accessor keyed by string literal.
* The string literal is converted to a compile-time index, and
* std::get<index> fetches the value from the base class.
* If the key doesn't exist in the keys, compilation will fail.
*/
template <typename Key,
std::enable_if_t<any_match<Key, Keys>::value, bool> = true>
inline decltype(auto) operator[](const Key& key) {
return Take<Key, Keys, Values&>(key, *this);
}
template <typename Key,
std::enable_if_t<any_match<Key, Keys>::value, bool> = true>
inline decltype(auto) operator[](const Key& key) const {
return Take<Key, Keys, const Values&>(key, *this);
}
/**
* Return a static constexpr list of all the keys in a tuple.
*/
static constexpr decltype(auto) StaticKeys() { return Keys(); }
/**
* Return a list of all the keys as strings.
*/
static std::vector<std::string> AllKeys() {
std::vector<std::string> rets;
std::apply([&](auto&&... key) { (rets.push_back(key.Str()), ...); },
Keys());
return rets;
}
/**
* Return a static list of all the values in a tuple.
*/
Values& AllValues() { return *this; }
/**
* Const version of AllValues.
*/
[[nodiscard]] const Values& AllValues() const { return *this; }
/**
* Convert the value tuple to a dynamic vector of values.
* This function is only enabled when Values is an instantiation of std::tuple,
* and when all elements of the values can be converted to Type.
*/
template <typename Type, bool IsTuple = is_tuple_v<Values>,
std::enable_if_t<IsTuple, bool> = true,
std::enable_if_t<all_convertible<Type, Values>::value, bool> = true>
[[nodiscard]] std::vector<Type> AllValues() const {
std::vector<Type> rets;
std::apply(
[&](auto&&... value) {
(rets.push_back(static_cast<Type>(value)), ...);
},
*static_cast<const Values*>(this));
return rets;
}
/**
* Convert the value vector to a vector of type `Type`.
* This function is only enabled when Values is an instantiation of
* std::vector.
*/
template <typename Type, bool IsTuple = is_tuple_v<Values>,
std::enable_if_t<!IsTuple, bool> = true>
std::vector<Type> AllValues() const {
return std::vector<Type>(this->begin(), this->end());
}
template <class F, bool IsTuple = is_tuple_v<Values>,
std::enable_if_t<IsTuple, bool> = true>
decltype(auto) Apply(F&& f) const {
ApplyZip(f, Keys(), *this, std::make_index_sequence<kSize>{});
}
};
/**
* Make a dict, which is actually a namedtuple in C++.
* Syntax is like
* auto d = MakeDict("abc"_.Bind(0.), "xyz"_.Bind(0.), "ijk"_.Bind(1));
* The above makes a dict { "abc": 0., "xyz": 0., "ijk": 1 }
*/
template <typename... Value>
decltype(auto) MakeDict(Value... v) {
return Dict(std::make_tuple(typename Value::Key()...),
std::make_tuple(v.v...));
}
template <
typename DictA, typename DictB,
typename AllKeys = tuple_cat_t<typename DictA::Keys, typename DictB::Keys>,
std::enable_if_t<is_tuple_v<typename DictA::Values> &&
is_tuple_v<typename DictB::Values>,
bool> = true>
decltype(auto) ConcatDict(const DictA& a, const DictB& b) {
auto c = std::tuple_cat(static_cast<const typename DictA::Values&>(a),
static_cast<const typename DictB::Values&>(b));
return Dict<AllKeys, decltype(c)>(std::move(c));
}
template <
typename DictA, typename DictB,
typename AllKeys = tuple_cat_t<typename DictA::Keys, typename DictB::Keys>,
std::enable_if_t<is_vector_v<typename DictA::Values> &&
                     is_vector_v<typename DictB::Values>, bool> = true,
std::enable_if_t<
std::is_same_v<typename DictA::Values, typename DictB::Values>, bool> =
true>
decltype(auto) ConcatDict(const DictA& a, const DictB& b) {
std::vector<typename DictA::Values::value_type> c;
c.insert(c.end(), a.begin(), a.end());
c.insert(c.end(), b.begin(), b.end());
return Dict<AllKeys, decltype(c)>(c);
}
/**
* Transform an input vector into an output vector.
* Calls std::transform and infers the element type of the output vector from
* the return type of the transform function.
*/
template <typename S, typename F,
typename R = decltype(std::declval<F>()(std::declval<S>()))>
std::vector<R> Transform(const std::vector<S>& src, F&& transform) {
std::vector<R> tgt;
std::transform(src.begin(), src.end(), std::back_inserter(tgt),
std::forward<F>(transform));
return tgt;
}
/**
* Static version of MakeArray.
* Takes a tuple of `template <typename T> Spec<T>`.
*/
template <typename... Spec>
std::vector<Array> MakeArray(const std::tuple<Spec...>& specs) {
std::vector<Array> rets;
std::apply([&](auto&&... spec) { (rets.push_back(Array(spec)), ...); },
specs);
return rets;
}
/**
* Dynamic version of MakeArray.
* Takes a vector of `ShapeSpec`.
*/
std::vector<Array> MakeArray(const std::vector<ShapeSpec>& specs) {
return {specs.begin(), specs.end()};
}
#endif // YGOENV_CORE_DICT_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ENV_H_
#define YGOENV_CORE_ENV_H_
#include <memory>
#include <random>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "ygoenv/core/env_spec.h"
#include "ygoenv/core/state_buffer_queue.h"
template <typename Dtype>
struct InitializeHelper {
static void Init(Array* arr) {}
};
template <typename Dtype>
struct InitializeHelper<Container<Dtype>> {
static void Init(Array* arr) {
auto* carr = reinterpret_cast<Container<Dtype>*>(arr->Data());
for (std::size_t i = 0; i < arr->size; ++i) {
new (carr + i) Container<Dtype>(nullptr);
}
}
};
template <typename Spec>
void InplaceInitialize(const Spec& spec, Array* arr) {
InitializeHelper<typename Spec::dtype>::Init(arr);
}
template <typename SpecTuple>
struct SpecToTArray;
template <typename... Args>
struct SpecToTArray<std::tuple<Args...>> {
using Type = std::tuple<TArray<typename Args::dtype>...>;
};
/**
* Single RL environment abstraction.
*/
template <typename EnvSpec>
class Env {
protected:
int max_num_players_;
EnvSpec spec_;
int env_id_, seed_;
std::mt19937 gen_;
private:
StateBufferQueue* sbq_;
int order_, current_step_{-1};
bool is_single_player_;
StateBuffer::WritableSlice slice_;
// for parsing single env action from input action batch
std::vector<ShapeSpec> action_specs_;
std::vector<bool> is_player_action_;
std::shared_ptr<std::vector<Array>> action_batch_;
std::vector<Array> raw_action_;
int env_index_;
public:
using Spec = EnvSpec;
using State =
Dict<typename EnvSpec::StateKeys,
typename SpecToTArray<typename EnvSpec::StateSpec::Values>::Type>;
using Action =
Dict<typename EnvSpec::ActionKeys,
typename SpecToTArray<typename EnvSpec::ActionSpec::Values>::Type>;
Env(const EnvSpec& spec, int env_id)
: max_num_players_(spec.config["max_num_players"_]),
spec_(spec),
env_id_(env_id),
seed_(spec.config["seed"_] + env_id),
gen_(seed_),
is_single_player_(max_num_players_ == 1),
action_specs_(spec.action_spec.template AllValues<ShapeSpec>()),
is_player_action_(Transform(action_specs_, [](const ShapeSpec& s) {
return (!s.shape.empty() && s.shape[0] == -1);
})) {
slice_.done_write = [] { LOG(INFO) << "Use `Allocate` to write state."; };
}
virtual ~Env() = default;
void SetAction(std::shared_ptr<std::vector<Array>> action_batch,
int env_index) {
action_batch_ = std::move(action_batch);
env_index_ = env_index;
}
void ParseAction() {
raw_action_.clear();
std::size_t action_size = action_batch_->size();
if (is_single_player_) {
for (std::size_t i = 0; i < action_size; ++i) {
if (is_player_action_[i]) {
raw_action_.emplace_back(
(*action_batch_)[i].Slice(env_index_, env_index_ + 1));
} else {
raw_action_.emplace_back((*action_batch_)[i][env_index_]);
}
}
} else {
std::vector<int> env_player_index;
int* player_env_id = static_cast<int*>((*action_batch_)[1].Data());
int player_offset = (*action_batch_)[1].Shape(0);
for (int i = 0; i < player_offset; ++i) {
if (player_env_id[i] == env_id_) {
env_player_index.push_back(i);
}
}
int player_num = env_player_index.size();
bool continuous = false;
int start = 0;
int end = 0;
if (player_num > 0) {
start = env_player_index[0];
end = env_player_index[player_num - 1] + 1;
continuous = (player_num == end - start);
}
for (std::size_t i = 0; i < action_size; ++i) {
if (is_player_action_[i]) {
if (continuous) {
raw_action_.emplace_back((*action_batch_)[i].Slice(start, end));
} else {
action_specs_[i].shape[0] = player_num;
Array arr(action_specs_[i]);
for (int j = 0; j < player_num; ++j) {
int player_index = env_player_index[j];
arr[j].Assign((*action_batch_)[i][player_index]);
}
raw_action_.emplace_back(std::move(arr));
}
} else {
raw_action_.emplace_back((*action_batch_)[i][env_index_]);
}
}
}
}
void EnvStep(StateBufferQueue* sbq, int order, bool reset) {
PreProcess(sbq, order, reset);
if (reset) {
Reset();
} else {
ParseAction();
Step(Action(std::move(raw_action_)));
raw_action_.clear();
}
PostProcess();
}
virtual void Reset() { throw std::runtime_error("reset not implemented"); }
virtual void Step(const Action& action) {
throw std::runtime_error("step not implemented");
}
virtual bool IsDone() { throw std::runtime_error("is_done not implemented"); }
protected:
void PreProcess(StateBufferQueue* sbq, int order, bool reset) {
sbq_ = sbq;
order_ = order;
if (reset) {
current_step_ = 0;
} else {
++current_step_;
}
}
void PostProcess() {
slice_.done_write();
// action_batch_.reset();
}
State Allocate(int player_num = 1) {
slice_ = sbq_->Allocate(player_num, order_);
State state(slice_.arr);
bool done = IsDone();
int max_episode_steps = spec_.config["max_episode_steps"_];
state["done"_] = done;
state["discount"_] = static_cast<float>(!done);
// dm_env.StepType.FIRST == 0
// dm_env.StepType.MID == 1
// dm_env.StepType.LAST == 2
state["step_type"_] = current_step_ == 0 ? 0 : done ? 2 : 1;
state["trunc"_] = done && (current_step_ >= max_episode_steps);
state["info:env_id"_] = env_id_;
state["elapsed_step"_] = current_step_;
int* player_env_id(static_cast<int*>(state["info:players.env_id"_].Data()));
for (int i = 0; i < player_num; ++i) {
player_env_id[i] = env_id_;
}
// Inplace initialize all container fields
int i = 0;
std::apply(
[&](auto&&... spec) {
(InplaceInitialize(spec, &slice_.arr[i++]), ...);
},
spec_.state_spec.AllValues());
return state;
}
};
#endif // YGOENV_CORE_ENV_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ENV_SPEC_H_
#define YGOENV_CORE_ENV_SPEC_H_
#include <limits>
#include <string>
#include "ygoenv/core/dict.h"
auto common_config =
MakeDict("num_envs"_.Bind(1), "batch_size"_.Bind(0), "num_threads"_.Bind(0),
"max_num_players"_.Bind(1), "thread_affinity_offset"_.Bind(-1),
"base_path"_.Bind(std::string("ygoenv")), "seed"_.Bind(42),
"gym_reset_return_info"_.Bind(false),
"max_episode_steps"_.Bind(std::numeric_limits<int>::max()));
// Note: this action order is hardcoded in async_envpool Send function
// and env ParseAction function for performance
auto common_action_spec = MakeDict("env_id"_.Bind(Spec<int>({})),
"players.env_id"_.Bind(Spec<int>({-1})));
// Note: this state order is hardcoded in async_envpool Recv function
auto common_state_spec =
MakeDict("info:env_id"_.Bind(Spec<int>({})),
"info:players.env_id"_.Bind(Spec<int>({-1})),
"elapsed_step"_.Bind(Spec<int>({})), "done"_.Bind(Spec<bool>({})),
"reward"_.Bind(Spec<float>({-1})),
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})),
"step_type"_.Bind(Spec<int>({})), "trunc"_.Bind(Spec<bool>({})));
/**
* EnvSpec class template; it constructs the env spec when a Config is passed.
*/
template <typename EnvFns>
class EnvSpec {
public:
using EnvFnsType = EnvFns;
using Config = decltype(ConcatDict(common_config, EnvFns::DefaultConfig()));
using ConfigKeys = typename Config::Keys;
using ConfigValues = typename Config::Values;
using StateSpec = decltype(ConcatDict(
common_state_spec, EnvFns::StateSpec(std::declval<Config>())));
using ActionSpec = decltype(ConcatDict(
common_action_spec, EnvFns::ActionSpec(std::declval<Config>())));
using StateKeys = typename StateSpec::Keys;
using ActionKeys = typename ActionSpec::Keys;
// For C++
Config config;
StateSpec state_spec;
ActionSpec action_spec;
static inline const Config kDefaultConfig =
ConcatDict(common_config, EnvFns::DefaultConfig());
EnvSpec() : EnvSpec(kDefaultConfig) {}
explicit EnvSpec(const ConfigValues& conf)
: config(conf),
state_spec(ConcatDict(common_state_spec, EnvFns::StateSpec(config))),
action_spec(
ConcatDict(common_action_spec, EnvFns::ActionSpec(config))) {
if (config["batch_size"_] > config["num_envs"_]) {
throw std::invalid_argument(
"It is required that batch_size <= num_envs, got num_envs = " +
std::to_string(config["num_envs"_]) +
", batch_size = " + std::to_string(config["batch_size"_]));
}
if (config["batch_size"_] == 0) {
config["batch_size"_] = config["num_envs"_];
}
}
};
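/*
 * Illustrative usage sketch (not part of the original header), assuming some
 * user-defined EnvFns type `MyEnvFns`: a spec is typically built by copying
 * the default config, overriding fields, and passing the raw value tuple.
 *   using MySpec = EnvSpec<MyEnvFns>;
 *   auto conf = MySpec::kDefaultConfig;
 *   conf["num_envs"_] = 8;
 *   conf["batch_size"_] = 4;   // must satisfy batch_size <= num_envs
 *   MySpec spec(conf.AllValues());
 */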
#endif // YGOENV_CORE_ENV_SPEC_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_ENVPOOL_H_
#define YGOENV_CORE_ENVPOOL_H_
#include <utility>
#include <vector>
#include "ygoenv/core/env_spec.h"
/**
* Templated subclass of EnvPool, to be overridden by the real EnvPool.
*/
template <typename EnvSpec>
class EnvPool {
public:
EnvSpec spec;
using Spec = EnvSpec;
using State = NamedVector<typename EnvSpec::StateKeys, std::vector<Array>>;
using Action = NamedVector<typename EnvSpec::ActionKeys, std::vector<Array>>;
explicit EnvPool(EnvSpec spec) : spec(std::move(spec)) {}
virtual ~EnvPool() = default;
protected:
virtual void Send(const std::vector<Array>& action) {
throw std::runtime_error("send not implemented");
}
virtual void Send(std::vector<Array>&& action) {
throw std::runtime_error("send not implemented");
}
virtual std::vector<Array> Recv() {
throw std::runtime_error("recv not implemented");
}
virtual void Reset(const Array& env_ids) {
throw std::runtime_error("reset not implemented");
}
};
#endif // YGOENV_CORE_ENVPOOL_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_PY_ENVPOOL_H_
#define YGOENV_CORE_PY_ENVPOOL_H_
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <exception>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ygoenv/core/envpool.h"
namespace py = pybind11;
/**
* Convert Array to py::array, with py::capsule
*/
template <typename dtype>
struct ArrayToNumpyHelper {
static py::array Convert(const Array& a) {
auto* ptr = new std::shared_ptr<char>(a.SharedPtr());
auto capsule = py::capsule(ptr, [](void* ptr) {
delete reinterpret_cast<std::shared_ptr<char>*>(ptr);
});
return py::array(a.Shape(), reinterpret_cast<dtype*>(a.Data()), capsule);
}
};
template <typename dtype>
struct ArrayToNumpyHelper<Container<dtype>> {
using UniquePtr = Container<dtype>;
static py::array Convert(const Array& a) {
auto* ptr_arr = reinterpret_cast<UniquePtr*>(a.Data());
auto* ptr =
new std::unique_ptr<py::object[]>(new py::object[a.size]); // NOLINT
auto capsule = py::capsule(ptr, [](void* ptr) {
delete reinterpret_cast<std::unique_ptr<py::object[]>*>(ptr); // NOLINT
});
for (std::size_t i = 0; i < a.size; ++i) {
auto* inner_ptr = new UniquePtr(std::move(ptr_arr[i]));
(ptr_arr + i)->~UniquePtr();
auto capsule = py::capsule(inner_ptr, [](void* inner_ptr) {
delete reinterpret_cast<UniquePtr*>(inner_ptr);
});
if (*inner_ptr == nullptr) {
(*ptr)[i] = py::none();
} else {
(*ptr)[i] =
py::array((*inner_ptr)->Shape(),
reinterpret_cast<dtype*>((*inner_ptr)->Data()), capsule);
}
}
return {py::dtype("object"), a.Shape(),
reinterpret_cast<py::object*>(ptr->get()), capsule};
}
};
template <typename dtype>
Array NumpyToArray(const py::array& arr) {
using ArrayT = py::array_t<dtype, py::array::c_style | py::array::forcecast>;
ArrayT arr_t(arr);
ShapeSpec spec(arr_t.itemsize(),
std::vector<int>(arr_t.shape(), arr_t.shape() + arr_t.ndim()));
return {spec, reinterpret_cast<char*>(arr_t.mutable_data())};
}
template <typename dtype>
Array NumpyToArrayIncRef(const py::array& arr) {
using ArrayT = py::array_t<dtype, py::array::c_style | py::array::forcecast>;
auto* arr_ptr = new ArrayT(arr);
ShapeSpec spec(
arr_ptr->itemsize(),
std::vector<int>(arr_ptr->shape(), arr_ptr->shape() + arr_ptr->ndim()));
return Array(spec, reinterpret_cast<char*>(arr_ptr->mutable_data()),
[arr_ptr](char* p) {
py::gil_scoped_acquire acquire;
delete arr_ptr;
});
}
template <typename Spec>
struct SpecTupleHelper {
static decltype(auto) Make(const Spec& spec) {
return std::make_tuple(py::dtype::of<typename Spec::dtype>(), spec.shape,
spec.bounds, spec.elementwise_bounds);
}
};
/**
* For the Container type, it is converted to a numpy array of numpy arrays.
* The spec itself describes the shape of the outer array, while inner_spec
* contains the spec of the inner array.
* Therefore the shape returned to python side has the format
* (outer_shape, inner_shape).
*/
template <typename dtype>
struct SpecTupleHelper<Spec<Container<dtype>>> {
static decltype(auto) Make(const Spec<Container<dtype>>& spec) {
return std::make_tuple(py::dtype::of<dtype>(),
std::make_tuple(spec.shape, spec.inner_spec.shape),
spec.inner_spec.bounds,
spec.inner_spec.elementwise_bounds);
}
};
template <typename... Spec>
decltype(auto) ExportSpecs(const std::tuple<Spec...>& specs) {
return std::apply(
[&](auto&&... spec) {
return std::make_tuple(SpecTupleHelper<Spec>::Make(spec)...);
},
specs);
}
template <typename EnvSpec>
class PyEnvSpec : public EnvSpec {
public:
using StateSpecT =
decltype(ExportSpecs(std::declval<typename EnvSpec::StateSpec>()));
using ActionSpecT =
decltype(ExportSpecs(std::declval<typename EnvSpec::ActionSpec>()));
StateSpecT py_state_spec;
ActionSpecT py_action_spec;
typename EnvSpec::ConfigValues py_config_values;
static std::vector<std::string> py_config_keys;
static std::vector<std::string> py_state_keys;
static std::vector<std::string> py_action_keys;
static typename EnvSpec::ConfigValues py_default_config_values;
explicit PyEnvSpec(const typename EnvSpec::ConfigValues& conf)
: EnvSpec(conf),
py_state_spec(ExportSpecs(EnvSpec::state_spec)),
py_action_spec(ExportSpecs(EnvSpec::action_spec)),
py_config_values(EnvSpec::config.AllValues()) {}
};
template <typename EnvSpec>
std::vector<std::string> PyEnvSpec<EnvSpec>::py_config_keys =
EnvSpec::Config::AllKeys();
template <typename EnvSpec>
std::vector<std::string> PyEnvSpec<EnvSpec>::py_state_keys =
EnvSpec::StateSpec::AllKeys();
template <typename EnvSpec>
std::vector<std::string> PyEnvSpec<EnvSpec>::py_action_keys =
EnvSpec::ActionSpec::AllKeys();
template <typename EnvSpec>
typename EnvSpec::ConfigValues PyEnvSpec<EnvSpec>::py_default_config_values =
EnvSpec::kDefaultConfig.AllValues();
/**
* Convert arrs to py::array according to specs, appending the results to ret.
*/
template <typename... Spec>
void ToNumpy(const std::vector<Array>& arrs, const std::tuple<Spec...>& specs,
std::vector<py::array>* ret) {
std::size_t index = 0;
std::apply(
[&](auto&&... spec) {
(ret->emplace_back(
ArrayToNumpyHelper<typename Spec::dtype>::Convert(arrs[index++])),
...);
},
specs);
}
template <typename... Spec>
void ToArray(const std::vector<py::array>& py_arrs,
const std::tuple<Spec...>& specs, std::vector<Array>* ret) {
std::size_t index = 0;
std::apply(
[&](auto&&... spec) {
(ret->emplace_back(
NumpyToArrayIncRef<typename Spec::dtype>(py_arrs[index++])),
...);
},
specs);
}
/**
* Templated subclass of EnvPool,
* to be overridden by the real EnvPool.
*/
template <typename EnvPool>
class PyEnvPool : public EnvPool {
public:
using PySpec = PyEnvSpec<typename EnvPool::Spec>;
PySpec py_spec;
static std::vector<std::string> py_state_keys;
static std::vector<std::string> py_action_keys;
explicit PyEnvPool(const PySpec& py_spec)
: EnvPool(py_spec), py_spec(py_spec) {}
/**
* py api
*/
void PySend(const std::vector<py::array>& action) {
std::vector<Array> arr;
arr.reserve(action.size());
ToArray(action, py_spec.action_spec, &arr);
py::gil_scoped_release release;
EnvPool::Send(arr); // delegate to the c++ api
}
/**
* py api
*/
std::vector<py::array> PyRecv() {
std::vector<Array> arr;
{
py::gil_scoped_release release;
arr = EnvPool::Recv();
DCHECK_EQ(arr.size(), std::tuple_size_v<typename EnvPool::State::Keys>);
}
std::vector<py::array> ret;
ret.reserve(EnvPool::State::kSize);
ToNumpy(arr, py_spec.state_spec, &ret);
return ret;
}
/**
* py api
*/
void PyReset(const py::array& env_ids) {
// PyArray arr = PyArray::From<int>(env_ids);
auto arr = NumpyToArrayIncRef<int>(env_ids);
py::gil_scoped_release release;
EnvPool::Reset(arr);
}
};
template <typename EnvPool>
std::vector<std::string> PyEnvPool<EnvPool>::py_state_keys =
PyEnvPool<EnvPool>::PySpec::py_state_keys;
template <typename EnvPool>
std::vector<std::string> PyEnvPool<EnvPool>::py_action_keys =
PyEnvPool<EnvPool>::PySpec::py_action_keys;
py::object abc_meta = py::module::import("abc").attr("ABCMeta");
/**
* Call this macro in the translation unit of each envpool instance
* It will register the envpool instance to the registry.
* The static bool status is local to the translation unit.
*/
#define REGISTER(MODULE, SPEC, ENVPOOL) \
py::class_<SPEC>(MODULE, "_" #SPEC, py::metaclass(abc_meta)) \
.def(py::init<const typename SPEC::ConfigValues&>()) \
.def_readonly("_config_values", &SPEC::py_config_values) \
.def_readonly("_state_spec", &SPEC::py_state_spec) \
.def_readonly("_action_spec", &SPEC::py_action_spec) \
.def_readonly_static("_state_keys", &SPEC::py_state_keys) \
.def_readonly_static("_action_keys", &SPEC::py_action_keys) \
.def_readonly_static("_config_keys", &SPEC::py_config_keys) \
.def_readonly_static("_default_config_values", \
&SPEC::py_default_config_values); \
py::class_<ENVPOOL>(MODULE, "_" #ENVPOOL, py::metaclass(abc_meta)) \
.def(py::init<const SPEC&>()) \
.def_readonly("_spec", &ENVPOOL::py_spec) \
.def("_recv", &ENVPOOL::PyRecv) \
.def("_send", &ENVPOOL::PySend) \
.def("_reset", &ENVPOOL::PyReset) \
.def_readonly_static("_state_keys", &ENVPOOL::py_state_keys) \
.def_readonly_static("_action_keys", \
&ENVPOOL::py_action_keys);
#endif // YGOENV_CORE_PY_ENVPOOL_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_SPEC_H_
#define YGOENV_CORE_SPEC_H_
#include <glog/logging.h>
#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>
static std::size_t Prod(const std::size_t* shape, std::size_t ndim) {
return std::accumulate(shape, shape + ndim, static_cast<std::size_t>(1),
std::multiplies<>());
}
class ShapeSpec {
public:
int element_size;
std::vector<int> shape;
ShapeSpec() = default;
ShapeSpec(int element_size, std::vector<int> shape_vec)
: element_size(element_size), shape(std::move(shape_vec)) {}
[[nodiscard]] ShapeSpec Batch(int batch_size) const {
std::vector<int> new_shape = {batch_size};
new_shape.insert(new_shape.end(), shape.begin(), shape.end());
return {element_size, std::move(new_shape)};
}
[[nodiscard]] std::vector<std::size_t> Shape() const {
auto s = std::vector<std::size_t>(shape.size());
for (std::size_t i = 0; i < shape.size(); ++i) {
s[i] = shape[i];
}
return s;
}
};
template <typename D>
class Spec : public ShapeSpec {
public:
using dtype = D; // NOLINT
std::tuple<dtype, dtype> bounds = {std::numeric_limits<dtype>::min(),
std::numeric_limits<dtype>::max()};
std::tuple<std::vector<dtype>, std::vector<dtype>> elementwise_bounds;
explicit Spec(std::vector<int>&& shape)
: ShapeSpec(sizeof(dtype), std::move(shape)) {}
explicit Spec(const std::vector<int>& shape)
: ShapeSpec(sizeof(dtype), shape) {}
/* init with constant bounds */
Spec(std::vector<int>&& shape, std::tuple<dtype, dtype>&& bounds)
: ShapeSpec(sizeof(dtype), std::move(shape)), bounds(std::move(bounds)) {}
Spec(const std::vector<int>& shape, const std::tuple<dtype, dtype>& bounds)
: ShapeSpec(sizeof(dtype), shape), bounds(bounds) {}
/* init with elementwise bounds */
Spec(std::vector<int>&& shape,
std::tuple<std::vector<dtype>, std::vector<dtype>>&& elementwise_bounds)
: ShapeSpec(sizeof(dtype), std::move(shape)),
elementwise_bounds(std::move(elementwise_bounds)) {}
Spec(const std::vector<int>& shape,
const std::tuple<std::vector<dtype>, std::vector<dtype>>&
elementwise_bounds)
: ShapeSpec(sizeof(dtype), shape),
elementwise_bounds(elementwise_bounds) {}
[[nodiscard]] Spec Batch(int batch_size) const {
std::vector<int> new_shape = {batch_size};
new_shape.insert(new_shape.end(), shape.begin(), shape.end());
return Spec(std::move(new_shape));
}
};
template <typename dtype>
class TArray;
template <typename dtype>
using Container = std::unique_ptr<TArray<dtype>>;
template <typename D>
class Spec<Container<D>> : public ShapeSpec {
public:
using dtype = Container<D>; // NOLINT
Spec<D> inner_spec;
explicit Spec(const std::vector<int>& shape, const Spec<D>& inner_spec)
: ShapeSpec(sizeof(Container<D>), shape), inner_spec(inner_spec) {}
explicit Spec(std::vector<int>&& shape, Spec<D>&& inner_spec)
: ShapeSpec(sizeof(Container<D>), std::move(shape)),
inner_spec(std::move(inner_spec)) {}
};
#endif  // YGOENV_CORE_SPEC_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_STATE_BUFFER_H_
#define YGOENV_CORE_STATE_BUFFER_H_
#ifndef MOODYCAMEL_DELETE_FUNCTION
#define MOODYCAMEL_DELETE_FUNCTION = delete
#endif
#include <atomic>
#include <functional>
#include <utility>
#include <vector>
#include "ygoenv/core/array.h"
#include "ygoenv/core/dict.h"
#include "ygoenv/core/spec.h"
#include "concurrentqueue/moodycamel/lightweightsemaphore.h"
/**
* Buffer of a batch of states, which is used as an intermediate storage device
* for the environments to write their state outputs of each step.
* There's a quota for how many envs' results are stored in this buffer,
* which is controlled by the batch argument in the constructor.
*/
class StateBuffer {
protected:
std::size_t batch_;
std::size_t max_num_players_;
std::vector<Array> arrays_;
std::vector<bool> is_player_state_;
std::atomic<uint64_t> offsets_{0};
std::atomic<std::size_t> alloc_count_{0};
std::atomic<std::size_t> done_count_{0};
moodycamel::LightweightSemaphore sem_;
public:
/**
* The return type of StateBuffer::Allocate is a slice of each state array
* that can be written by the caller. When writing is done, the caller should
* invoke done_write.
*/
struct WritableSlice {
std::vector<Array> arr;
std::function<void()> done_write;
};
/**
* Create a StateBuffer instance with the specs and per-field player-state
* flags provided.
*/
StateBuffer(std::size_t batch, std::size_t max_num_players,
const std::vector<ShapeSpec>& specs,
std::vector<bool> is_player_state)
: batch_(batch),
max_num_players_(max_num_players),
arrays_(MakeArray(specs)),
is_player_state_(std::move(is_player_state)) {}
/**
* Tries to allocate a piece of memory without locking.
* If this buffer runs out of quota, an out_of_range exception is thrown.
* Externally, the caller has to catch the exception and handle it accordingly.
*/
WritableSlice Allocate(std::size_t num_players, int order = -1) {
DCHECK_LE(num_players, max_num_players_);
std::size_t alloc_count = alloc_count_.fetch_add(1);
if (alloc_count < batch_) {
// Make an increment atomically on two uint32_t counters simultaneously
// This avoids lock
uint64_t increment = static_cast<uint64_t>(num_players) << 32 | 1;
uint64_t offsets = offsets_.fetch_add(increment);
uint32_t player_offset = offsets >> 32;
uint32_t shared_offset = offsets;
DCHECK_LE((std::size_t)shared_offset + 1, batch_);
DCHECK_LE((std::size_t)(player_offset + num_players),
batch_ * max_num_players_);
if (order != -1 && max_num_players_ == 1) {
// single player with sync setting: return ordered data
player_offset = shared_offset = order;
}
std::vector<Array> state;
state.reserve(arrays_.size());
for (std::size_t i = 0; i < arrays_.size(); ++i) {
const Array& a = arrays_[i];
if (is_player_state_[i]) {
state.emplace_back(
a.Slice(player_offset, player_offset + num_players));
} else {
state.emplace_back(a[shared_offset]);
}
}
return WritableSlice{.arr = std::move(state),
.done_write = [this]() { Done(); }};
}
DLOG(INFO) << "Allocation failed, continue to the next block of memory";
throw std::out_of_range("StateBuffer out of storage");
}
[[nodiscard]] std::pair<uint32_t, uint32_t> Offsets() const {
uint32_t player_offset = offsets_ >> 32;
uint32_t shared_offset = offsets_;
return {player_offset, shared_offset};
}
/**
* When the allocated memory has been filled, the user of the memory will
* call this callback to notify StateBuffer that its part has been written.
*/
void Done(std::size_t num = 1) {
std::size_t done_count = done_count_.fetch_add(num);
if (done_count + num == batch_) {
sem_.signal();
}
}
/**
* Blocks until the entire buffer is ready, i.e., all quota has been
* handed out and every user has called Done.
*/
std::vector<Array> Wait(std::size_t additional_done_count = 0) {
if (additional_done_count > 0) {
Done(additional_done_count);
}
while (!sem_.wait()) {
}
// when things are all done, compact the buffer.
uint64_t offsets = offsets_;
uint32_t player_offset = (offsets >> 32);
uint32_t shared_offset = offsets;
DCHECK_EQ((std::size_t)shared_offset, batch_ - additional_done_count);
std::vector<Array> ret;
ret.reserve(arrays_.size());
for (std::size_t i = 0; i < arrays_.size(); ++i) {
const Array& a = arrays_[i];
if (is_player_state_[i]) {
ret.emplace_back(a.Truncate(player_offset));
} else {
ret.emplace_back(a.Truncate(shared_offset));
}
}
return ret;
}
};
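/*
 * Illustrative usage sketch (not part of the original header) of the write
 * protocol described above, with placeholder constructor arguments: each
 * producer fills its slice and signals done_write; one consumer collects the
 * compacted batch with Wait.
 *   StateBuffer buffer(batch, max_num_players, specs, is_player_state);
 *   auto slice = buffer.Allocate(1);   // one slot for a single-player env
 *   // ... write the env's state into slice.arr ...
 *   slice.done_write();                // counts towards the batch quota
 *   std::vector<Array> state = buffer.Wait();   // blocks until the batch is full
 */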
#endif // YGOENV_CORE_STATE_BUFFER_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_STATE_BUFFER_QUEUE_H_
#define YGOENV_CORE_STATE_BUFFER_QUEUE_H_
#include <algorithm>
#include <cstdint>
#include <memory>
#include <thread>
#include <utility>
#include <vector>
#include "ygoenv/core/array.h"
#include "ygoenv/core/circular_buffer.h"
#include "ygoenv/core/spec.h"
#include "ygoenv/core/state_buffer.h"
class StateBufferQueue {
protected:
std::size_t batch_;
std::size_t max_num_players_;
std::vector<bool> is_player_state_;
std::vector<ShapeSpec> specs_;
std::size_t queue_size_;
std::vector<std::unique_ptr<StateBuffer>> queue_;
std::atomic<uint64_t> alloc_count_, done_ptr_, alloc_tail_;
// Create stock statebuffers in a background thread
CircularBuffer<std::unique_ptr<StateBuffer>> stock_buffer_;
std::vector<std::thread> create_buffer_thread_;
std::atomic<bool> quit_;
public:
StateBufferQueue(std::size_t batch_env, std::size_t num_envs,
std::size_t max_num_players,
const std::vector<ShapeSpec>& specs)
: batch_(batch_env),
max_num_players_(max_num_players),
is_player_state_(Transform(specs,
[](const ShapeSpec& s) {
return (!s.shape.empty() &&
s.shape[0] == -1);
})),
specs_(Transform(specs,
[=](ShapeSpec s) {
if (!s.shape.empty() && s.shape[0] == -1) {
// If first dim is num_players
s.shape[0] = batch_ * max_num_players_;
return s;
}
return s.Batch(batch_);
})),
// twice as many buffers as needed for all the envs
queue_size_((num_envs / batch_env + 2) * 2),
queue_(queue_size_), // circular buffer
alloc_count_(0),
done_ptr_(0),
stock_buffer_((num_envs / batch_env + 2) * 2),
quit_(false) {
// Only initialize first half of the buffer
// At the consumption of each block, the first consuming thread
// will allocate a new state buffer and append to the tail.
// alloc_tail_ = num_envs / batch_env + 2;
for (auto& q : queue_) {
q = std::make_unique<StateBuffer>(batch_, max_num_players_, specs_,
is_player_state_);
}
std::size_t processor_count = std::thread::hardware_concurrency();
// hardcode here :(
std::size_t create_buffer_thread_num = std::max(1UL, processor_count / 64);
for (std::size_t i = 0; i < create_buffer_thread_num; ++i) {
create_buffer_thread_.emplace_back(std::thread([&]() {
while (true) {
stock_buffer_.Put(std::make_unique<StateBuffer>(
batch_, max_num_players_, specs_, is_player_state_));
if (quit_) {
break;
}
}
}));
}
}
~StateBufferQueue() {
// stop the thread
quit_ = true;
for (std::size_t i = 0; i < create_buffer_thread_.size(); ++i) {
stock_buffer_.Get();
}
for (auto& t : create_buffer_thread_) {
t.join();
}
}
/**
* Allocate a slice of memory for the current env to write.
* This function is used from the producer side.
* It is safe to access from multiple threads.
*/
StateBuffer::WritableSlice Allocate(std::size_t num_players, int order = -1) {
std::size_t pos = alloc_count_.fetch_add(1);
std::size_t offset = (pos / batch_) % queue_size_;
// if (pos % batch_ == 0) {
// // At the time a new statebuffer is accessed, the first visitor
// allocate
// // a new state buffer and put it at the back of the queue.
// std::size_t insert_pos = alloc_tail_.fetch_add(1);
// std::size_t insert_offset = insert_pos % queue_size_;
// queue_[insert_offset].reset(
// new StateBuffer(batch_, max_num_players_, specs_,
// is_player_state_));
// }
return queue_[offset]->Allocate(num_players, order);
}
/**
* Wait for the state buffer at the head to be ready.
* This function can only be accessed from one thread.
*
* BIG CAVEAT:
* Wait should be accessed from only one thread.
* If Wait is accessed from multiple threads, it is only safe if the finish
* time of each state buffer is in the same order as the allocation time.
*/
std::vector<Array> Wait(std::size_t additional_done_count = 0) {
std::unique_ptr<StateBuffer> newbuf = stock_buffer_.Get();
std::size_t pos = done_ptr_.fetch_add(1);
std::size_t offset = pos % queue_size_;
auto arr = queue_[offset]->Wait(additional_done_count);
if (additional_done_count > 0) {
// move pointer to the next block
alloc_count_.fetch_add(additional_done_count);
}
std::swap(queue_[offset], newbuf);
return arr;
}
};
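/*
 * Illustrative usage sketch (not part of the original header), with
 * placeholder constructor arguments: many env threads call Allocate
 * concurrently, while a single consumer thread drains finished batches with
 * Wait, matching the caveat documented above.
 *   StateBufferQueue queue(batch_size, num_envs, max_num_players, specs);
 *   auto slice = queue.Allocate(1);    // producer side, thread-safe
 *   slice.done_write();
 *   auto batch = queue.Wait();         // consumer side, one thread only
 */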
#endif // YGOENV_CORE_STATE_BUFFER_QUEUE_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_TUPLE_UTILS_H_
#define YGOENV_CORE_TUPLE_UTILS_H_
#include <functional>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
template <class T, class Tuple>
struct Index;
template <class T, class... Types>
struct Index<T, std::tuple<T, Types...>> {
static constexpr std::size_t kValue = 0;
};
template <class T, class U, class... Types>
struct Index<T, std::tuple<U, Types...>> {
static constexpr std::size_t kValue =
1 + Index<T, std::tuple<Types...>>::kValue;
};
template <class F, class K, class V, std::size_t... I>
decltype(auto) ApplyZip(F&& f, K&& k, V&& v,
std::index_sequence<I...> /*unused*/) {
return std::invoke(std::forward<F>(f),
std::make_tuple(I, std::get<I>(std::forward<K>(k)),
std::get<I>(std::forward<V>(v)))...);
}
template <typename... T>
using tuple_cat_t = decltype(std::tuple_cat(std::declval<T>()...)); // NOLINT
template <typename TupleType, typename T, std::size_t... Is>
decltype(auto) TupleFromVectorImpl(std::index_sequence<Is...> /*unused*/,
const std::vector<T>& arguments) {
return TupleType(arguments[Is]...);
}
template <typename TupleType, typename T, std::size_t... Is>
decltype(auto) TupleFromVectorImpl(std::index_sequence<Is...> /*unused*/,
std::vector<T>&& arguments) {
return TupleType(std::move(arguments[Is])...);
}
template <typename TupleType, typename V>
decltype(auto) TupleFromVector(V&& arguments) {
return TupleFromVectorImpl<TupleType>(
std::make_index_sequence<std::tuple_size_v<TupleType>>{},
std::forward<V>(arguments));
}
#endif  // YGOENV_CORE_TUPLE_UTILS_H_
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef YGOENV_CORE_TYPE_UTILS_H_
#define YGOENV_CORE_TYPE_UTILS_H_
#include <functional>
#include <tuple>
#include <type_traits>
#include <vector>
template <class T, class TupleTs>
struct any_match;
template <class T, template <typename...> class Tuple, typename... Ts>
struct any_match<T, Tuple<Ts...>> : std::disjunction<std::is_same<T, Ts>...> {};
template <class T, class TupleTs>
struct all_match;
template <class T, template <typename...> class Tuple, typename... Ts>
struct all_match<T, Tuple<Ts...>> : std::conjunction<std::is_same<T, Ts>...> {};
template <class To, class TupleTs>
struct all_convertible;
template <class To, template <typename...> class Tuple, typename... Fs>
struct all_convertible<To, Tuple<Fs...>>
: std::conjunction<std::is_convertible<Fs, To>...> {};
template <typename T>
constexpr bool is_tuple_v = false; // NOLINT
template <typename... types>
constexpr bool is_tuple_v<std::tuple<types...>> = true; // NOLINT
template <typename T>
constexpr bool is_vector_v = false; // NOLINT
template <typename VT>
constexpr bool is_vector_v<std::vector<VT>> = true; // NOLINT
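/*
 * Illustrative usage sketch (not part of the original header): these traits
 * drive the SFINAE constraints on Dict, NamedVector, and ConcatDict above.
 *   static_assert(is_tuple_v<std::tuple<int, float>>);
 *   static_assert(is_vector_v<std::vector<int>>);
 *   static_assert(any_match<int, std::tuple<float, int>>::value);
 *   static_assert(all_convertible<double, std::tuple<int, float>>::value);
 */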
#endif // YGOENV_CORE_TYPE_UTILS_H_
#!/usr/bin/env python3
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dummy Env in EnvPool."""
from ygoenv.python.api import py_env
from .dummy_envpool import _DummyEnvPool, _DummyEnvSpec
DummyEnvSpec, DummyDMEnvPool, DummyGymEnvPool, DummyGymnasiumEnvPool = py_env(
_DummyEnvSpec, _DummyEnvPool
)
__all__ = [
"DummyEnvSpec",
"DummyDMEnvPool",
"DummyGymEnvPool",
"DummyGymnasiumEnvPool",
]
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ygoenv/dummy/dummy_envpool.h"
#include "ygoenv/core/py_envpool.h"
/**
* Wrap the `DummyEnvSpec` and `DummyEnvPool` with the corresponding `PyEnvSpec`
* and `PyEnvPool` template.
*/
using DummyEnvSpec = PyEnvSpec<dummy::DummyEnvSpec>;
using DummyEnvPool = PyEnvPool<dummy::DummyEnvPool>;
/**
* Finally, call the REGISTER macro to expose them to python
*/
PYBIND11_MODULE(dummy_ygoenv, m) { REGISTER(m, DummyEnvSpec, DummyEnvPool) }
// Copyright 2021 Garena Online Private Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ENVPOOL_DUMMY_DUMMY_ENVPOOL_H_
#define ENVPOOL_DUMMY_DUMMY_ENVPOOL_H_
#include <memory>
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h"
namespace dummy {
class DummyEnvFns {
public:
/**
* Returns a dict; keys are the names of the configurable variables of this
* env, values store the default values of the corresponding variables.
*
* EnvPool will append some common fields to your configuration; currently
* these are the envpool-specific configurations that define the behavior of
* envpool.
*
* 1. num_envs: number of envs to be launched in envpool
* 2. batch_size: the batch_size when interacting with the envpool
* 3. num_threads: the number of threads to run all the envs
* 4. thread_affinity_offset: sets the thread affinity of the threads
* 5. base_path: contains the path of the envpool python package
* 6. seed: random seed
*
* There are also single-env-specific configurations:
*
* 7. max_num_players: defines the number of players in a single env.
*
*/
static decltype(auto) DefaultConfig() {
return MakeDict("state_num"_.Bind(10), "action_num"_.Bind(6));
}
/**
* Returns a dict, keys are the names of the states of this env,
* values are the ArraySpec of the state (as each state is stored in an
* array).
*
* The array spec can be created by calling `Spec<dtype>(shape, bounds)`.
*
* Similarly, envpool also appends to this state spec; these are:
*
* 1. info:env_id: an int array of shape [batch_size]; when there's a
* batch of states, it tells the user which `env_id` each of these states
* comes from.
* 2. info:players.env_id: similar to `env_id`, but with a shape of
* [total_num_player], where `total_num_player` is the total number of
* players summed over all envs in the batch.
*
* For example, if in one batch we have states from envs [1, 3, 4],
* in env 1 there are players [1, 2], in env 3 there are players [2, 3, 4],
* and in env 4 there is player [1]. Then:
* `info:env_id == [1, 3, 4]`
* `info:players.env_id == [1, 1, 3, 3, 3, 4]`
*
* 3. elapsed_step: the total elapsed steps of the envs.
* 4. done: whether it is the end of episode for each env.
*/
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:raw"_.Bind(Spec<int>({-1, conf["state_num"_]})),
"obs:dyn"_.Bind(Spec<Container<int>>(
{-1}, Spec<int>({-1, conf["state_num"_]}))),
"info:players.done"_.Bind(Spec<bool>({-1})),
"info:players.id"_.Bind(
Spec<int>({-1}, {0, conf["max_num_players"_]})));
}
/**
* Returns a dict, keys are the names of the actions of this env,
* values are the ArraySpec of the actions (each action is stored in an
* array).
*
* Similarly, envpool also appends to this action spec; these are:
*
* 1. env_id
* 2. players.env_id
*
* Their meanings are the same as in the `StateSpec`.
*/
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
return MakeDict("list_action"_.Bind(Spec<double>({6})),
"players.action"_.Bind(Spec<int>({-1})),
"players.id"_.Bind(Spec<int>({-1})));
}
};
/**
* Create a DummyEnvSpec by passing the above functions to EnvSpec.
*/
using DummyEnvSpec = EnvSpec<DummyEnvFns>;
/**
* The main part of the single env.
* It inherits and implements the interfaces defined in Env specialized by the
* DummyEnvSpec we defined above.
*/
class DummyEnv : public Env<DummyEnvSpec> {
protected:
int state_{0};
public:
/**
* Initialize the env; in this function we perform tasks like loading the game
* ROM, etc.
*/
DummyEnv(const Spec& spec, int env_id) : Env<DummyEnvSpec>(spec, env_id) {
if (seed_ < 1) {
seed_ = 1;
}
}
/**
* Reset this single env; this has the same meaning as OpenAI Gym's reset.
* The reset function usually returns the state after reset; here, we first
* call `Allocate` to create the state (which is managed by envpool), and
* then populate it with the post-reset state.
*/
void Reset() override {
state_ = 0;
int num_players =
max_num_players_ <= 1 ? 1 : state_ % (max_num_players_ - 1) + 1;
// Ask envpool to allocate a piece of memory where we can write the state
// after reset.
auto state = Allocate(num_players);
// write the information of the next state into the state.
for (int i = 0; i < num_players; ++i) {
state["info:players.id"_][i] = i;
state["info:players.done"_][i] = IsDone();
state["obs:raw"_](i, 0) = state_;
state["obs:raw"_](i, 1) = 0;
state["reward"_][i] = -i;
// dynamic array
Container<int>& dyn = state["obs:dyn"_][i];
// new spec
auto dyn_spec = ::Spec<int>({env_id_ + 1, spec_.config["state_num"_]});
// use this spec to create an array
auto* array = new TArray<int>(dyn_spec);
// perform some normal array writing
array->Fill(env_id_);
// finally pass it to dynamic array
dyn.reset(array);
}
}
/**
* Step is the central function of a single env.
* It takes an action, executes the env, and returns the next state.
*
* Similar to Reset, Step also returns the state through the `Allocate` function.
*
*/
void Step(const Action& action) override {
++state_;
int num_players =
max_num_players_ <= 1 ? 1 : state_ % (max_num_players_ - 1) + 1;
// Parse the action, and execute the env (dummy env has nothing to do)
int action_num = action["players.env_id"_].Shape(0);
for (int i = 0; i < action_num; ++i) {
if (static_cast<int>(action["players.env_id"_][i]) != env_id_) {
action_num = 0;
}
}
// Check if actions can successfully pass into envpool
double x = action["list_action"_][0];
for (int i = 0; i < 6; ++i) {
double y = action["list_action"_][i];
CHECK_EQ(x, y);
}
// Ask envpool to allocate a piece of memory where we can write the state
// after reset.
auto state = Allocate(num_players);
// write the information of the next state into the state.
for (int i = 0; i < num_players; ++i) {
state["info:players.id"_][i] = i;
state["info:players.done"_][i] = IsDone();
state["obs:raw"_](i, 0) = state_;
state["obs:raw"_](i, 1) = action_num;
state["reward"_][i] = -i;
Container<int>& dyn = state["obs:dyn"_][i];
auto dyn_spec = ::Spec<int>({env_id_ + 1, spec_.config["state_num"_]});
dyn = std::make_unique<TArray<int>>(dyn_spec);
dyn->Fill(env_id_);
}
}
/**
* Whether the single env has ended the current episode.
*/
bool IsDone() override { return state_ >= seed_; }
};
/**
* Pass the DummyEnv we defined above as a template parameter to the
* AsyncEnvPool template; it gives us a parallelized version of the single env.
*/
using DummyEnvPool = AsyncEnvPool<DummyEnv>;
} // namespace dummy
#endif // ENVPOOL_DUMMY_DUMMY_ENVPOOL_H_
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classic control env registration."""
from ygoenv.registration import register
register(
task_id="Dummy-v0",
import_path="ygoenv.dummy",
spec_cls="DummyEnvSpec",
dm_cls="DummyDMEnvPool",
gym_cls="DummyGymEnvPool",
gymnasium_cls="DummyGymnasiumEnvPool",
)
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point for all envs' registration."""
try:
import ygoenv.ygopro.registration # noqa: F401
except ImportError:
pass
try:
import ygoenv.dummy.registration # noqa: F401
except ImportError:
pass
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python interface for EnvPool."""
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Api wrapper layer for EnvPool."""
from typing import Tuple, Type
from .dm_envpool import DMEnvPoolMeta
from .env_spec import EnvSpecMeta
from .gym_envpool import GymEnvPoolMeta
from .gymnasium_envpool import GymnasiumEnvPoolMeta
from .protocol import EnvPool, EnvSpec
def py_env(
envspec: Type[EnvSpec], envpool: Type[EnvPool]
) -> Tuple[Type[EnvSpec], Type[EnvPool], Type[EnvPool], Type[EnvPool]]:
"""Initialize EnvPool for users."""
# remove the _ prefix added when registering cpp class via pybind
spec_name = envspec.__name__[1:]
pool_name = envpool.__name__[1:]
return (
EnvSpecMeta(spec_name, (envspec,), {}), # type: ignore[return-value]
DMEnvPoolMeta(pool_name.replace("EnvPool", "DMEnvPool"), (envpool,), {}),
GymEnvPoolMeta(pool_name.replace("EnvPool", "GymEnvPool"), (envpool,), {}),
GymnasiumEnvPoolMeta(
pool_name.replace("EnvPool", "GymnasiumEnvPool"), (envpool,), {}
),
)
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper function for data convertion."""
from collections import namedtuple
from typing import Any, Dict, List, Tuple, Type
import dm_env
import gym
import gymnasium
import numpy as np
import optree
from optree import PyTreeSpec
from .protocol import ArraySpec
ACTION_THRESHOLD = 2**20
def to_nested_dict(flatten_dict: Dict[str, Any],
generator: Type = dict) -> Dict[str, Any]:
"""Convert a flat dict to a hierarchical dict.
The input dict's hierarchy is denoted by ``.``.
Example:
::
>>> to_nested_dict({"a.b": 2333, "a.c": 666})
{"a": {"b": 2333, "c": 666}}
Args:
flatten_dict: a dict whose keys list contains ``.`` for hierarchical
representation.
generator: a type of mapping. Defaults to ``dict``.
"""
ret: Dict[str, Any] = generator()
for k, v in flatten_dict.items():
segments = k.split(".")
ptr = ret
for s in segments[:-1]:
if s not in ptr:
ptr[s] = generator()
ptr = ptr[s]
ptr[segments[-1]] = v
return ret
def to_namedtuple(name: str, hdict: Dict) -> Tuple:
"""Convert a hierarchical dict to namedtuple."""
return namedtuple(name, hdict.keys())(
*[
to_namedtuple(k, v) if isinstance(v, Dict) else v
for k, v in hdict.items()
]
)
def dm_spec_transform(
name: str, spec: ArraySpec, spec_type: str
) -> dm_env.specs.Array:
"""Transform ArraySpec to dm_env compatible specs."""
if np.prod(np.abs(spec.shape)) == 1 and \
np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD:
# special treatment for discrete action space
return dm_env.specs.DiscreteArray(
name=name,
dtype=spec.dtype,
num_values=int(spec.maximum - spec.minimum + 1),
)
return dm_env.specs.BoundedArray(
name=name,
shape=[s for s in spec.shape if s != -1],
dtype=spec.dtype,
minimum=spec.minimum,
maximum=spec.maximum,
)
def gym_spec_transform(name: str, spec: ArraySpec, spec_type: str) -> gym.Space:
"""Transform ArraySpec to gym.Env compatible spaces."""
if np.prod(np.abs(spec.shape)) == 1 and \
np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD:
# special treatment for discrete action space
discrete_range = int(spec.maximum - spec.minimum + 1)
try:
return gym.spaces.Discrete(n=discrete_range, start=int(spec.minimum))
except TypeError: # old gym version doesn't have `start`
return gym.spaces.Discrete(n=discrete_range)
return gym.spaces.Box(
shape=[s for s in spec.shape if s != -1],
dtype=spec.dtype,
low=spec.minimum,
high=spec.maximum,
)
def gymnasium_spec_transform(
name: str, spec: ArraySpec, spec_type: str
) -> gymnasium.Space:
"""Transform ArraySpec to gymnasium.Env compatible spaces."""
if np.prod(np.abs(spec.shape)) == 1 and \
np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD:
# special treatment for discrete action space
discrete_range = int(spec.maximum - spec.minimum + 1)
return gymnasium.spaces.Discrete(n=discrete_range, start=int(spec.minimum))
return gymnasium.spaces.Box(
shape=[s for s in spec.shape if s != -1],
dtype=spec.dtype,
low=spec.minimum,
high=spec.maximum,
)
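# Illustrative sketch of the spec transforms above (ArraySpec is imported from
# .protocol; a -1 dimension is the variable batch axis and is dropped):
#
#   >>> spec = ArraySpec(np.int32, [1], (0, 3), (None, None))
#   >>> gym_spec_transform("action", spec, "act")   # scalar, minimum 0, small maximum
#   Discrete(4)
#   >>> spec = ArraySpec(np.float32, [-1, 4], (-1.0, 1.0), (None, None))
#   >>> gym_spec_transform("obs", spec, "obs")      # anything else -> Box with shape (4,)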
def dm_structure(
root_name: str,
keys: List[str],
) -> Tuple[List[Tuple[int, ...]], List[int], PyTreeSpec]:
"""Convert flat keys into tree structure for namedtuple construction."""
new_keys = []
for key in keys:
if key in ["obs", "info"]: # special treatment for single-node obs/info
key = f"obs:{key}"
key = key.replace("info:", "obs:") # merge obs and info together
key = key.replace("obs:", f"{root_name}:") # compatible with to_namedtuple
new_keys.append(key.replace(":", "."))
dict_tree = to_nested_dict(dict(zip(new_keys, list(range(len(new_keys))))))
structure = to_namedtuple(root_name, dict_tree)
paths, indices, treespec = optree.tree_flatten_with_path(structure)
return paths, indices, treespec
def gym_structure(
keys: List[str]
) -> Tuple[List[Tuple[str, ...]], List[int], PyTreeSpec]:
"""Convert flat keys into tree structure for dict construction."""
keys = [k.replace(":", ".") for k in keys]
dict_tree = to_nested_dict(dict(zip(keys, list(range(len(keys))))))
paths, indices, treespec = optree.tree_flatten_with_path(dict_tree)
return paths, indices, treespec
gymnasium_structure = gym_structure
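# Illustrative example: gym_structure only normalizes the separators and flattens
# the resulting nested dict with optree (leaves come out in key-sorted order):
#
#   >>> paths, indices, treespec = gym_structure(["obs:a", "obs:b", "reward"])
#   >>> paths
#   [('obs', 'a'), ('obs', 'b'), ('reward',)]
#   >>> indices
#   [0, 1, 2]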
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvPool meta class for dm_env API."""
from abc import ABC, ABCMeta
from typing import Any, Dict, List, Tuple, Union
import dm_env
import numpy as np
import optree
from dm_env import TimeStep
from .data import dm_structure
from .envpool import EnvPoolMixin
from .utils import check_key_duplication
class DMEnvPoolMixin(ABC):
"""Special treatment for dm_env API."""
def observation_spec(self: Any) -> Tuple:
"""Observation spec from EnvSpec."""
if not hasattr(self, "_dm_observation_spec"):
self._dm_observation_spec = self.spec.observation_spec()
return self._dm_observation_spec
def action_spec(self: Any) -> Union[dm_env.specs.Array, Tuple]:
"""Action spec from EnvSpec."""
if not hasattr(self, "_dm_action_spec"):
self._dm_action_spec = self.spec.action_spec()
return self._dm_action_spec
class DMEnvPoolMeta(ABCMeta):
"""Additional wrapper for EnvPool dm_env API."""
def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any:
"""Check internal config and initialize data format convertion."""
base = parents[0]
try:
from .lax import XlaMixin
parents = (
base, DMEnvPoolMixin, EnvPoolMixin, XlaMixin, dm_env.Environment
)
except ImportError:
def _xla(self: Any) -> None:
raise RuntimeError("XLA is disabled. To enable XLA please install jax.")
attrs["xla"] = _xla
parents = (base, DMEnvPoolMixin, EnvPoolMixin, dm_env.Environment)
state_keys = base._state_keys
action_keys = base._action_keys
check_key_duplication(name, "state", state_keys)
check_key_duplication(name, "action", action_keys)
    state_paths, state_idx, treespec = dm_structure("State", state_keys)
def _to_dm(
self: Any,
state_values: List[np.ndarray],
reset: bool,
return_info: bool,
) -> TimeStep:
values = (state_values[i] for i in state_idx)
      state = optree.tree_unflatten(treespec, values)
timestep = TimeStep(
step_type=state.step_type,
observation=state.State,
reward=state.reward,
discount=state.discount,
)
return timestep
attrs["_to"] = _to_dm
subcls = super().__new__(cls, name, parents, attrs)
def init(self: Any, spec: Any) -> None:
"""Set self.spec to EnvSpecMeta."""
super(subcls, self).__init__(spec)
self.spec = spec
setattr(subcls, "__init__", init) # noqa: B010
return subcls
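# Sketch of the resulting dm_env interface (field names depend on the concrete
# env's state keys): recv()/step() return a dm_env.TimeStep whose observation is
# the nested ``State`` namedtuple built by dm_structure above, e.g.
#
#   ts = env.recv()
#   ts.step_type, ts.reward, ts.discount, ts.observation.obs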
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvSpec mixin definition."""
import pprint
from abc import ABC, ABCMeta
from collections import namedtuple
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union
import dm_env
import gym
import gymnasium
from .data import (
dm_spec_transform,
gym_spec_transform,
gymnasium_spec_transform,
to_namedtuple,
to_nested_dict,
)
from .protocol import ArraySpec, EnvSpec
from .utils import check_key_duplication
class EnvSpecMixin(ABC):
"""Mixin class for EnvSpec, exposed to EnvSpecMeta."""
gen_config: Type
@property
def config(self: EnvSpec) -> NamedTuple:
"""Configuration used to create the current EnvSpec."""
return self.gen_config(*self._config_values)
@property
def reward_threshold(self: EnvSpec) -> Optional[float]:
"""Reward threshold, None for no threshold."""
try:
return self.config.reward_threshold # type: ignore
except AttributeError:
return None
@property
def state_array_spec(self: EnvSpec) -> Dict[str, Any]:
"""Specs of the states of the environment.
Returns:
      state_spec: A dict whose keys are the state names and whose values are
        the corresponding ``ArraySpec`` objects.
"""
state_spec = [ArraySpec(*s) for s in self._state_spec]
return dict(zip(self._state_keys, state_spec))
@property
def action_array_spec(self: EnvSpec) -> Dict[str, Any]:
"""Specs of the actions of the environment.
Returns:
      action_spec: A dict whose keys are the action names and whose values are
        the corresponding ``ArraySpec`` objects.
"""
action_spec = [ArraySpec(*s) for s in self._action_spec]
return dict(zip(self._action_keys, action_spec))
def observation_spec(self: EnvSpec) -> Tuple:
"""Convert internal state_spec to dm_env compatible format.
Returns:
observation_spec: A namedtuple (maybe nested) that contains all keys
that start with ``obs`` or ``info`` with their corresponding specs.
"""
spec = self.state_array_spec
spec = {
k.replace("obs:", "").replace("info:", ""):
dm_spec_transform(k.replace(":", ".").split(".")[-1], v, "obs")
for k, v in spec.items()
if k.startswith("obs") or k.startswith("info")
}
return to_namedtuple("State", to_nested_dict(spec))
def action_spec(self: EnvSpec) -> Union[dm_env.specs.Array, Tuple]:
"""Convert internal action_spec to dm_env compatible format.
Returns:
action_spec: A single dm_env.specs.Array or a dict (maybe nested) that
contains all keys that start with ``action`` with their corresponding
specs.
Note:
If the original action_spec has a length of 3 ("env_id",
"players.env_id", *), it returns the last spec instead of all for
simplicity.
"""
spec = self.action_array_spec
if len(spec) == 3:
# only env_id, players.env_id, action
spec.pop("env_id")
spec.pop("players.env_id")
return dm_spec_transform(
list(spec.keys())[0],
list(spec.values())[0], "act"
)
spec = {
k: dm_spec_transform(k.split(".")[-1], v, "act") for k, v in spec.items()
}
return to_namedtuple("Action", to_nested_dict(spec))
@property
def observation_space(self: EnvSpec) -> Union[gym.Space, Dict[str, Any]]:
"""Convert internal state_spec to gym.Env compatible format.
Returns:
observation_space: A dict (maybe nested) that contains all keys
that start with ``obs`` with their corresponding specs.
Note:
If only one key starts with ``obs``, it returns that space instead of
all for simplicity.
"""
spec = self.state_array_spec
spec = {
k.replace("obs:", ""):
gym_spec_transform(k.replace(":", ".").split(".")[-1], v, "obs")
for k, v in spec.items()
if k.startswith("obs")
}
if len(spec) == 1:
return list(spec.values())[0]
return to_nested_dict(spec, gym.spaces.Dict)
@property
def action_space(self: EnvSpec) -> Union[gym.Space, Dict[str, Any]]:
"""Convert internal action_spec to gym.Env compatible format.
Returns:
action_space: A dict (maybe nested) that contains key-value paired
corresponding specs.
Note:
If the original action_spec has a length of 3 ("env_id",
"players.env_id", *), it returns the last space instead of all for
simplicity.
"""
spec = self.action_array_spec
if len(spec) == 3:
# only env_id, players.env_id, action
spec.pop("env_id")
spec.pop("players.env_id")
return gym_spec_transform(
list(spec.keys())[0],
list(spec.values())[0], "act"
)
spec = {
k: gym_spec_transform(k.split(".")[-1], v, "act") for k, v in spec.items()
}
return to_nested_dict(spec, gym.spaces.Dict)
@property
def gymnasium_observation_space(
self: EnvSpec
) -> Union[gymnasium.Space, Dict[str, Any]]:
"""Convert internal state_spec to gymnasium.Env compatible format.
Returns:
observation_space: A dict (maybe nested) that contains all keys
that start with ``obs`` with their corresponding specs.
Note:
If only one key starts with ``obs``, it returns that space instead of
all for simplicity.
"""
spec = self.state_array_spec
spec = {
k.replace("obs:", ""):
gymnasium_spec_transform(k.replace(":", ".").split(".")[-1], v, "obs")
for k, v in spec.items()
if k.startswith("obs")
}
if len(spec) == 1:
return list(spec.values())[0]
return to_nested_dict(spec, gymnasium.spaces.Dict)
@property
def gymnasium_action_space(
self: EnvSpec
) -> Union[gymnasium.Space, Dict[str, Any]]:
"""Convert internal action_spec to gymnasium.Env compatible format.
Returns:
action_space: A dict (maybe nested) that contains key-value paired
corresponding specs.
Note:
If the original action_spec has a length of 3 ("env_id",
"players.env_id", *), it returns the last space instead of all for
simplicity.
"""
spec = self.action_array_spec
if len(spec) == 3:
# only env_id, players.env_id, action
spec.pop("env_id")
spec.pop("players.env_id")
return gymnasium_spec_transform(
list(spec.keys())[0],
list(spec.values())[0], "act"
)
spec = {
k: gymnasium_spec_transform(k.split(".")[-1], v, "act")
for k, v in spec.items()
}
return to_nested_dict(spec, gymnasium.spaces.Dict)
def __repr__(self: EnvSpec) -> str:
"""Prettify debug info."""
    # strip the leading "Config" from the namedtuple repr, keeping only "(...)"
    config_info = pprint.pformat(self.config)[6:]
return f"{self.__class__.__name__}{config_info}"
class EnvSpecMeta(ABCMeta):
"""Additional checker and wrapper for EnvSpec."""
def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any:
"""Check keys and initialize namedtuple config."""
base = parents[0]
parents = (base, EnvSpecMixin)
config_keys = base._config_keys
check_key_duplication(name, "config", config_keys)
config_keys: List[str] = list(
map(lambda s: s.replace(".", "_"), config_keys)
)
defaults: Tuple = base._default_config_values
attrs["gen_config"] = namedtuple("Config", config_keys, defaults=defaults)
return super().__new__(cls, name, parents, attrs)
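# Usage sketch (hypothetical task id): once a spec class has been wrapped by
# EnvSpecMeta (e.g. via py_env or make_spec from the registry below), its
# configuration and spaces are available directly:
#
#   spec = make_spec("SomeTask-v0", num_envs=4)
#   spec.config.num_envs        # -> 4
#   spec.observation_space      # gym-style space(s)
#   spec.observation_spec()     # dm_env-style nested namedtuple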
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvPool Mixin class for meta class definition."""
import pprint
import warnings
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import optree
from dm_env import TimeStep
from .protocol import EnvPool, EnvSpec
class EnvPoolMixin(ABC):
"""Mixin class for EnvPool, exposed to EnvPoolMeta."""
_spec: EnvSpec
def _check_action(self: EnvPool, actions: List[np.ndarray]) -> None:
if hasattr(self, "_check_action_finished"): # only check once
return
self._check_action_finished = True
for a, (k, v) in zip(actions, self.spec.action_array_spec.items()):
if v.dtype != a.dtype:
raise RuntimeError(
f"Expected dtype {v.dtype} with action \"{k}\", got {a.dtype}"
)
shape = tuple(v.shape)
if len(shape) > 0 and shape[0] == -1:
if a.shape[1:] != shape[1:]:
raise RuntimeError(
f"Expected shape {shape} with action \"{k}\", got {a.shape}"
)
else:
if len(a.shape) == 0 or a.shape[1:] != shape:
raise RuntimeError(
f"Expected shape {('num_env', *shape)} with action \"{k}\", "
f"got {a.shape}"
)
def _from(
self: EnvPool,
action: Union[Dict[str, Any], np.ndarray],
env_id: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
"""Convert action to C++-acceptable format."""
if isinstance(action, dict):
paths, values, _ = optree.tree_flatten_with_path(action)
adict = {'.'.join(p): v for p, v in zip(paths, values)}
else: # only 3 keys in action_keys
if not hasattr(self, "_last_action_type"):
self._last_action_type = self._spec._action_spec[-1][0]
if not hasattr(self, "_last_action_name"):
self._last_action_name = self._spec._action_keys[-1]
if isinstance(action, np.ndarray):
# else it could be a jax array, when using xla
action = action.astype(
self._last_action_type, # type: ignore
order='C',
)
adict = {self._last_action_name: action} # type: ignore
if env_id is None:
if "env_id" not in adict:
adict["env_id"] = self.all_env_ids
else:
adict["env_id"] = env_id.astype(np.int32)
if "players.env_id" not in adict:
adict["players.env_id"] = adict["env_id"]
if not hasattr(self, "_action_names"):
self._action_names = self._spec._action_keys
return list(map(lambda k: adict[k], self._action_names)) # type: ignore
def __len__(self: EnvPool) -> int:
"""Return the number of environments."""
return self.config["num_envs"]
@property
def all_env_ids(self: EnvPool) -> np.ndarray:
"""All env_id in numpy ndarray with dtype=np.int32."""
if not hasattr(self, "_all_env_ids"):
self._all_env_ids = np.arange(self.config["num_envs"], dtype=np.int32)
return self._all_env_ids # type: ignore
@property
def is_async(self: EnvPool) -> bool:
"""Return if this env is in sync mode or async mode."""
return self.config["batch_size"] > 0 and self.config[
"num_envs"] != self.config["batch_size"]
def seed(self: EnvPool, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Set the seed for all environments (abandoned)."""
warnings.warn(
"The `seed` function in envpool is abandoned. "
"You can set seed by envpool.make(..., seed=seed) instead.",
stacklevel=2
)
def send(
self: EnvPool,
action: Union[Dict[str, Any], np.ndarray],
env_id: Optional[np.ndarray] = None,
) -> None:
"""Send actions into EnvPool."""
action = self._from(action, env_id)
self._check_action(action)
self._send(action)
def recv(
self: EnvPool,
reset: bool = False,
return_info: bool = True,
) -> Union[TimeStep, Tuple]:
"""Recv a batch state from EnvPool."""
state_list = self._recv()
return self._to(state_list, reset, return_info)
def async_reset(self: EnvPool) -> None:
"""Follows the async semantics, reset the envs in env_ids."""
self._reset(self.all_env_ids)
def step(
self: EnvPool,
action: Union[Dict[str, Any], np.ndarray],
env_id: Optional[np.ndarray] = None,
) -> Union[TimeStep, Tuple]:
"""Perform one step with multiple environments in EnvPool."""
self.send(action, env_id)
return self.recv(reset=False, return_info=True)
def reset(
self: EnvPool,
env_id: Optional[np.ndarray] = None,
) -> Union[TimeStep, Tuple]:
"""Reset envs in env_id.
This behavior is not defined in async mode.
"""
if env_id is None:
env_id = self.all_env_ids
self._reset(env_id)
return self.recv(
reset=True, return_info=self.config["gym_reset_return_info"]
)
@property
def config(self: EnvPool) -> Dict[str, Any]:
"""Config dict of this class."""
return dict(zip(self._spec._config_keys, self._spec._config_values))
def __repr__(self: EnvPool) -> str:
"""Prettify the debug information."""
config = self.config
config_str = ", ".join(
[f"{k}={pprint.pformat(v)}" for k, v in config.items()]
)
return f"{self.__class__.__name__}({config_str})"
def __str__(self: EnvPool) -> str:
"""Prettify the debug information."""
return self.__repr__()
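# Usage sketch (hypothetical task id; the exact return signature follows the
# chosen API and gym version): EnvPoolMixin provides both a synchronous ``step``
# and the lower-level asynchronous ``send``/``recv`` pair, where ``actions`` is
# a dict or np.ndarray matching the action space.
#
#   env = make_gym("SomeTask-v0", num_envs=8)
#   obs, info = env.reset()                                    # sync mode
#   obs, rew, terminated, truncated, info = env.step(actions)
#
#   env.async_reset()                                          # async mode
#   obs, rew, terminated, truncated, info = env.recv()
#   env.send(actions, env_id=info["env_id"])                   # env_id is typically in info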
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvPool meta class for gym.Env API."""
from abc import ABC, ABCMeta
from typing import Any, Dict, List, Tuple, Union
import gym
import numpy as np
import optree
from packaging import version
from .data import gym_structure
from .envpool import EnvPoolMixin
from .utils import check_key_duplication
class GymEnvPoolMixin(ABC):
"""Special treatment for gym API."""
@property
def observation_space(self: Any) -> Union[gym.Space, Dict[str, Any]]:
"""Observation space from EnvSpec."""
if not hasattr(self, "_gym_observation_space"):
self._gym_observation_space = self.spec.observation_space
return self._gym_observation_space
@property
def action_space(self: Any) -> Union[gym.Space, Dict[str, Any]]:
"""Action space from EnvSpec."""
if not hasattr(self, "_gym_action_space"):
self._gym_action_space = self.spec.action_space
return self._gym_action_space
class GymEnvPoolMeta(ABCMeta, gym.Env.__class__):
"""Additional wrapper for EnvPool gym.Env API."""
def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any:
"""Check internal config and initialize data format convertion."""
base = parents[0]
try:
from .lax import XlaMixin
parents = (base, GymEnvPoolMixin, EnvPoolMixin, XlaMixin, gym.Env)
except ImportError:
def _xla(self: Any) -> None:
raise RuntimeError("XLA is disabled. To enable XLA please install jax.")
attrs["xla"] = _xla
parents = (base, GymEnvPoolMixin, EnvPoolMixin, gym.Env)
state_keys = base._state_keys
action_keys = base._action_keys
check_key_duplication(name, "state", state_keys)
check_key_duplication(name, "action", action_keys)
    state_paths, state_idx, treespec = gym_structure(state_keys)
new_gym_api = version.parse(gym.__version__) >= version.parse("0.26.0")
def _to_gym(
self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool
) -> Union[
Any,
Tuple[Any, Any],
Tuple[Any, np.ndarray, np.ndarray, Any],
Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any],
]:
values = (state_values[i] for i in state_idx)
      state = optree.tree_unflatten(treespec, values)
if reset and not (return_info or new_gym_api):
return state["obs"]
info = state["info"]
if not new_gym_api:
info["TimeLimit.truncated"] = state["trunc"]
info["elapsed_step"] = state["elapsed_step"]
if reset:
return state["obs"], info
if new_gym_api:
terminated = state["done"] & ~state["trunc"]
return state["obs"], state["reward"], terminated, state["trunc"], info
return state["obs"], state["reward"], state["done"], info
attrs["_to"] = _to_gym
subcls = super().__new__(cls, name, parents, attrs)
def init(self: Any, spec: Any) -> None:
"""Set self.spec to EnvSpecMeta."""
super(subcls, self).__init__(spec)
self.spec = spec
setattr(subcls, "__init__", init) # noqa: B010
return subcls
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EnvPool meta class for gymnasium.Env API."""
from abc import ABC, ABCMeta
from typing import Any, Dict, List, Tuple, Union
import gymnasium
import numpy as np
import optree
from .data import gymnasium_structure
from .envpool import EnvPoolMixin
from .utils import check_key_duplication
class GymnasiumEnvPoolMixin(ABC):
"""Special treatment for gymnasim API."""
@property
def observation_space(self: Any) -> Union[gymnasium.Space, Dict[str, Any]]:
"""Observation space from EnvSpec."""
if not hasattr(self, "_gym_observation_space"):
self._gym_observation_space = self.spec.gymnasium_observation_space
return self._gym_observation_space
@property
def action_space(self: Any) -> Union[gymnasium.Space, Dict[str, Any]]:
"""Action space from EnvSpec."""
if not hasattr(self, "_gym_action_space"):
self._gym_action_space = self.spec.gymnasium_action_space
return self._gym_action_space
class GymnasiumEnvPoolMeta(ABCMeta, gymnasium.Env.__class__):
"""Additional wrapper for EnvPool gymnasium.Env API."""
def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any:
"""Check internal config and initialize data format convertion."""
base = parents[0]
try:
from .lax import XlaMixin
parents = (
base, GymnasiumEnvPoolMixin, EnvPoolMixin, XlaMixin, gymnasium.Env
)
except ImportError:
def _xla(self: Any) -> None:
raise RuntimeError("XLA is disabled. To enable XLA please install jax.")
attrs["xla"] = _xla
parents = (base, GymnasiumEnvPoolMixin, EnvPoolMixin, gymnasium.Env)
state_keys = base._state_keys
action_keys = base._action_keys
check_key_duplication(name, "state", state_keys)
check_key_duplication(name, "action", action_keys)
    state_paths, state_idx, treespec = gymnasium_structure(state_keys)
def _to_gymnasium(
self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool
) -> Union[
Any,
Tuple[Any, Any],
Tuple[Any, np.ndarray, np.ndarray, Any],
Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any],
]:
values = (state_values[i] for i in state_idx)
      state = optree.tree_unflatten(treespec, values)
info = state["info"]
info["elapsed_step"] = state["elapsed_step"]
if reset:
return state["obs"], info
terminated = state["done"] & ~state["trunc"]
return state["obs"], state["reward"], terminated, state["trunc"], info
attrs["_to"] = _to_gymnasium
subcls = super().__new__(cls, name, parents, attrs)
def init(self: Any, spec: Any) -> None:
"""Set self.spec to EnvSpecMeta."""
super(subcls, self).__init__(spec)
self.spec = spec
setattr(subcls, "__init__", init) # noqa: B010
return subcls
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protocol of C++ EnvPool."""
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Type,
Union,
)
import dm_env
import gym
import numpy as np
from dm_env import TimeStep
try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol # type: ignore
class EnvSpec(Protocol):
"""Cpp EnvSpec class."""
_config_keys: List[str]
_default_config_values: Tuple
gen_config: Type
def __init__(self, config: Tuple):
"""Protocol for constructor of EnvSpec."""
@property
def _state_spec(self) -> Tuple:
"""Cpp private _state_spec."""
@property
def _action_spec(self) -> Tuple:
"""Cpp private _action_spec."""
@property
def _state_keys(self) -> List:
"""Cpp private _state_keys."""
@property
def _action_keys(self) -> List:
"""Cpp private _action_keys."""
@property
def _config_values(self) -> Tuple:
"""Cpp private _config_values."""
@property
def config(self) -> NamedTuple:
"""Configuration used to create the current EnvSpec."""
@property
def state_array_spec(self) -> Dict[str, Any]:
"""Specs of the states of the environment in ArraySpec format."""
@property
def action_array_spec(self) -> Dict[str, Any]:
"""Specs of the actions of the environment in ArraySpec format."""
def observation_spec(self) -> Dict[str, Any]:
"""Specs of the observations of the environment in dm_env format."""
def action_spec(self) -> Union[dm_env.specs.Array, Dict[str, Any]]:
"""Specs of the actions of the environment in dm_env format."""
@property
def observation_space(self) -> Dict[str, Any]:
"""Specs of the observations of the environment in gym.Env format."""
@property
def action_space(self) -> Union[gym.Space, Dict[str, Any]]:
"""Specs of the actions of the environment in gym.Env format."""
@property
def reward_threshold(self) -> Optional[float]:
"""Reward threshold, None for no threshold."""
class ArraySpec(object):
"""Spec of numpy array."""
def __init__(
self, dtype: Type, shape: List[int], bounds: Tuple[Any, Any],
element_wise_bounds: Tuple[Any, Any]
):
"""Constructor of ArraySpec."""
self.dtype = dtype
self.shape = shape
if element_wise_bounds[0]:
self.minimum = np.array(element_wise_bounds[0])
else:
self.minimum = bounds[0]
if element_wise_bounds[1]:
self.maximum = np.array(element_wise_bounds[1])
else:
self.maximum = bounds[1]
def __repr__(self) -> str:
"""Beautify debug info."""
return (
f"ArraySpec(shape={self.shape}, dtype={self.dtype}, "
f"minimum={self.minimum}, maximum={self.maximum})"
)
class EnvPool(Protocol):
"""Cpp PyEnvpool class interface."""
_state_keys: List[str]
_action_keys: List[str]
spec: Any
def __init__(self, spec: EnvSpec):
"""Constructor of EnvPool."""
def __len__(self) -> int:
"""Return the number of environments."""
@property
def _spec(self) -> EnvSpec:
"""Cpp env spec."""
@property
def _action_spec(self) -> List:
"""Cpp action spec."""
def _check_action(self, actions: List) -> None:
"""Check action shapes."""
def _recv(self) -> List[np.ndarray]:
"""Cpp private _recv method."""
def _send(self, action: List[np.ndarray]) -> None:
"""Cpp private _send method."""
def _reset(self, env_id: np.ndarray) -> None:
"""Cpp private _reset method."""
def _from(
self,
action: Union[Dict[str, Any], np.ndarray],
env_id: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
"""Convertion for input action."""
def _to(
self,
state: List[np.ndarray],
reset: bool,
return_info: bool,
) -> Union[TimeStep, Tuple]:
"""A switch of to_dm and to_gym for output state."""
@property
def all_env_ids(self) -> np.ndarray:
"""All env_id in numpy ndarray with dtype=np.int32."""
@property
def is_async(self) -> bool:
"""Return if this env is in sync mode or async mode."""
@property
def observation_space(self) -> Union[gym.Space, Dict[str, Any]]:
"""Gym observation space."""
@property
def action_space(self) -> Union[gym.Space, Dict[str, Any]]:
"""Gym action space."""
def observation_spec(self) -> Tuple:
"""Dm observation spec."""
def action_spec(self) -> Union[dm_env.specs.Array, Tuple]:
"""Dm action spec."""
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Set the seed for all environments."""
@property
def config(self) -> Dict[str, Any]:
"""Envpool config."""
def send(
self,
action: Union[Dict[str, Any], np.ndarray],
env_id: Optional[np.ndarray] = None,
) -> None:
"""Envpool send wrapper."""
def recv(
self,
reset: bool = False,
return_info: bool = True,
) -> Union[TimeStep, Tuple]:
"""Envpool recv wrapper."""
def async_reset(self) -> None:
"""Envpool async reset interface."""
def step(
self,
action: Union[Dict[str, Any], np.ndarray],
env_id: Optional[np.ndarray] = None,
) -> Union[TimeStep, Tuple]:
"""Envpool step interface that performs send/recv."""
def reset(
self,
env_id: Optional[np.ndarray] = None,
) -> Union[TimeStep, Tuple]:
"""Envpool reset interface."""
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper function for Python API."""
from typing import Any, List
import numpy as np
def check_key_duplication(cls: Any, keytype: str, keys: List[str]) -> None:
"""Check if there's any duplicated keys in ``keys``."""
ukeys, counts = np.unique(keys, return_counts=True)
if not np.all(counts == 1):
dup_keys = ukeys[counts > 1]
raise SystemError(
f"{cls} c++ code error. {keytype} keys {list(dup_keys)} are duplicated. "
f"Please report to the author of {cls}."
)
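# Illustrative example: duplicated keys coming from the C++ binding fail fast at
# class-creation time.
#
#   check_key_duplication("DummyEnvPool", "state", ["obs", "obs", "reward"])
#   # raises SystemError mentioning the duplicated key ['obs']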
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global env registry."""
import importlib
import os
from typing import Any, Dict, List, Tuple
import gym
from packaging import version
base_path = os.path.abspath(os.path.dirname(__file__))
class EnvRegistry:
"""A collection of available envs."""
def __init__(self) -> None:
"""Constructor of EnvRegistry."""
self.specs: Dict[str, Tuple[str, str, Dict[str, Any]]] = {}
self.envpools: Dict[str, Dict[str, Tuple[str, str]]] = {}
def register(
self, task_id: str, import_path: str, spec_cls: str, dm_cls: str,
gym_cls: str, gymnasium_cls: str, **kwargs: Any
) -> None:
"""Register EnvSpec and EnvPool in global EnvRegistry."""
assert task_id not in self.specs
if "base_path" not in kwargs:
kwargs["base_path"] = base_path
self.specs[task_id] = (import_path, spec_cls, kwargs)
self.envpools[task_id] = {
"dm": (import_path, dm_cls),
"gym": (import_path, gym_cls),
"gymnasium": (import_path, gymnasium_cls)
}
def make(self, task_id: str, env_type: str, **kwargs: Any) -> Any:
"""Make envpool."""
new_gym_api = version.parse(gym.__version__) >= version.parse("0.26.0")
if "gym_reset_return_info" not in kwargs:
kwargs["gym_reset_return_info"] = new_gym_api
if new_gym_api and not kwargs["gym_reset_return_info"]:
raise ValueError(
"You are using gym>=0.26.0 but passed `gym_reset_return_info=False`. "
"The new gym API requires environments to return an info dictionary "
"after resets."
)
assert task_id in self.specs, \
f"{task_id} is not supported, `envpool.list_all_envs()` may help."
assert env_type in ["dm", "gym", "gymnasium"]
spec = self.make_spec(task_id, **kwargs)
import_path, envpool_cls = self.envpools[task_id][env_type]
return getattr(importlib.import_module(import_path), envpool_cls)(spec)
def make_dm(self, task_id: str, **kwargs: Any) -> Any:
"""Make dm_env compatible envpool."""
return self.make(task_id, "dm", **kwargs)
def make_gym(self, task_id: str, **kwargs: Any) -> Any:
"""Make gym.Env compatible envpool."""
return self.make(task_id, "gym", **kwargs)
def make_gymnasium(self, task_id: str, **kwargs: Any) -> Any:
"""Make gymnasium.Env compatible envpool."""
return self.make(task_id, "gymnasium", **kwargs)
def make_spec(self, task_id: str, **make_kwargs: Any) -> Any:
"""Make EnvSpec."""
import_path, spec_cls, kwargs = self.specs[task_id]
kwargs = {**kwargs, **make_kwargs}
# check arguments
if "seed" in kwargs: # Issue 214
INT_MAX = 2**31
assert -INT_MAX <= kwargs["seed"] < INT_MAX, \
f"Seed should be in range of int32, got {kwargs['seed']}"
if "num_envs" in kwargs:
assert kwargs["num_envs"] >= 1
if "batch_size" in kwargs:
assert 0 <= kwargs["batch_size"] <= kwargs["num_envs"]
if "max_num_players" in kwargs:
assert 1 <= kwargs["max_num_players"]
spec_cls = getattr(importlib.import_module(import_path), spec_cls)
config = spec_cls.gen_config(**kwargs)
return spec_cls(config)
def list_all_envs(self) -> List[str]:
"""Return all available task_id."""
return list(self.specs.keys())
# use a global EnvRegistry
registry = EnvRegistry()
register = registry.register
make = registry.make
make_dm = registry.make_dm
make_gym = registry.make_gym
make_gymnasium = registry.make_gymnasium
make_spec = registry.make_spec
list_all_envs = registry.list_all_envs
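# Usage sketch (hypothetical task and class names): an env package registers
# itself once at import time, after which any of the three APIs can be built
# through the module-level helpers above:
#
#   register(
#     task_id="Dummy-v0",
#     import_path="ygoenv.dummy",
#     spec_cls="DummyEnvSpec",
#     dm_cls="DummyDMEnvPool",
#     gym_cls="DummyGymEnvPool",
#     gymnasium_cls="DummyGymnasiumEnvPool",
#   )
#   env = make("Dummy-v0", env_type="gymnasium", num_envs=4, seed=0)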
from ygoenv.python.api import py_env
from .ygopro_ygoenv import (
_YGOProEnvPool,
_YGOProEnvSpec,
init_module,
)
(
YGOProEnvSpec,
YGOProDMEnvPool,
YGOProGymEnvPool,
YGOProGymnasiumEnvPool,
) = py_env(_YGOProEnvSpec, _YGOProEnvPool)
__all__ = [
"YGOProEnvSpec",
"YGOProDMEnvPool",
"YGOProGymEnvPool",
"YGOProGymnasiumEnvPool",
]
from ygoenv.registration import register
register(
task_id="YGOPro-v0",
import_path="ygoenv.ygopro",
spec_cls="YGOProEnvSpec",
dm_cls="YGOProDMEnvPool",
gym_cls="YGOProGymEnvPool",
gymnasium_cls="YGOProGymnasiumEnvPool",
)
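# Usage sketch: after this registration, the YGOPro environment can be created
# through the generic registry interface; extra keyword arguments are forwarded
# to YGOProEnvSpec.gen_config, and the module-level ``init_module`` imported in
# the package __init__ may need to be called first to load game assets (this is
# an assumption, not shown here):
#
#   from ygoenv.registration import make
#   env = make("YGOPro-v0", env_type="gymnasium", num_envs=4)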
#include "ygoenv/ygopro/ygopro.h"
#include "ygoenv/core/py_envpool.h"
using YGOProEnvSpec = PyEnvSpec<ygopro::YGOProEnvSpec>;
using YGOProEnvPool = PyEnvPool<ygopro::YGOProEnvPool>;
PYBIND11_MODULE(ygopro_ygoenv, m) {
REGISTER(m, YGOProEnvSpec, YGOProEnvPool)
m.def("init_module", &ygopro::init_module);
}
This source diff could not be displayed because it is too large. You can view the blob instead.