Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
745c67f9
Commit
745c67f9
authored
Mar 13, 2024
by
Biluo Shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix multi select error
parent
15e1e4e7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
12 deletions
+21
-12
scripts/battle.py
scripts/battle.py
+2
-2
scripts/eval.py
scripts/eval.py
+0
-1
scripts/ppo.py
scripts/ppo.py
+2
-0
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+17
-9
No files found.
scripts/battle.py
View file @
745c67f9
...
@@ -140,8 +140,8 @@ if __name__ == "__main__":
...
@@ -140,8 +140,8 @@ if __name__ == "__main__":
code_list
=
f
.
readlines
()
code_list
=
f
.
readlines
()
embedding_shape
=
len
(
code_list
)
embedding_shape
=
len
(
code_list
)
L
=
args
.
num_layers
L
=
args
.
num_layers
agent1
=
Agent
(
args
.
num_channels
,
L
,
L
,
1
,
embedding_shape
)
.
to
(
device
)
agent1
=
Agent
(
args
.
num_channels
,
L
,
L
,
2
,
embedding_shape
)
.
to
(
device
)
agent2
=
Agent
(
args
.
num_channels
,
L
,
L
,
1
,
embedding_shape
)
.
to
(
device
)
agent2
=
Agent
(
args
.
num_channels
,
L
,
L
,
2
,
embedding_shape
)
.
to
(
device
)
for
agent
,
ckpt
in
zip
([
agent1
,
agent2
],
[
args
.
checkpoint1
,
args
.
checkpoint2
]):
for
agent
,
ckpt
in
zip
([
agent1
,
agent2
],
[
args
.
checkpoint1
,
args
.
checkpoint2
]):
state_dict
=
torch
.
load
(
ckpt
,
map_location
=
device
)
state_dict
=
torch
.
load
(
ckpt
,
map_location
=
device
)
...
...
scripts/eval.py
View file @
745c67f9
...
@@ -153,7 +153,6 @@ if __name__ == "__main__":
...
@@ -153,7 +153,6 @@ if __name__ == "__main__":
embedding_shape
=
len
(
code_list
)
embedding_shape
=
len
(
code_list
)
L
=
args
.
num_layers
L
=
args
.
num_layers
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
2
,
embedding_shape
)
.
to
(
device
)
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
2
,
embedding_shape
)
.
to
(
device
)
# agent = agent.eval()
if
args
.
checkpoint
:
if
args
.
checkpoint
:
state_dict
=
torch
.
load
(
args
.
checkpoint
,
map_location
=
device
)
state_dict
=
torch
.
load
(
args
.
checkpoint
,
map_location
=
device
)
if
not
args
.
compile
:
if
not
args
.
compile
:
...
...
scripts/ppo.py
View file @
745c67f9
...
@@ -275,6 +275,8 @@ def main():
...
@@ -275,6 +275,8 @@ def main():
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
train_step
=
torch
.
compile
(
train_step
,
mode
=
args
.
compile
)
train_step
=
torch
.
compile
(
train_step
,
mode
=
args
.
compile
)
else
:
traced_model
=
agent
# ALGO Logic: Storage setup
# ALGO Logic: Storage setup
obs
=
create_obs
(
obs_space
,
(
args
.
num_steps
,
args
.
local_num_envs
),
device
)
obs
=
create_obs
(
obs_space
,
(
args
.
num_steps
,
args
.
local_num_envs
),
device
)
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
745c67f9
...
@@ -1392,7 +1392,6 @@ protected:
...
@@ -1392,7 +1392,6 @@ protected:
ankerl
::
unordered_dense
::
map
<
std
::
string
,
int
>
ms_spec2idx_
;
ankerl
::
unordered_dense
::
map
<
std
::
string
,
int
>
ms_spec2idx_
;
std
::
vector
<
int
>
ms_r_idxs_
;
std
::
vector
<
int
>
ms_r_idxs_
;
// discard hand cards
// discard hand cards
bool
discard_hand_
=
false
;
bool
discard_hand_
=
false
;
...
@@ -1470,6 +1469,7 @@ public:
...
@@ -1470,6 +1469,7 @@ public:
}
}
turn_count_
=
0
;
turn_count_
=
0
;
ms_idx_
=
-
1
;
history_actions_0_
.
Zero
();
history_actions_0_
.
Zero
();
history_actions_1_
.
Zero
();
history_actions_1_
.
Zero
();
...
@@ -1710,7 +1710,6 @@ public:
...
@@ -1710,7 +1710,6 @@ public:
for
(
int
i
=
0
;
i
<
ms_r_idxs_
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
ms_r_idxs_
.
size
();
++
i
)
{
resp_buf_
[
i
+
1
]
=
ms_r_idxs_
[
i
];
resp_buf_
[
i
+
1
]
=
ms_r_idxs_
[
i
];
}
}
// fmt::println("{}, {}", ms_r_idxs_.size(), ms_r_idxs_);
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
YGO_SetResponseb
(
pduel_
,
resp_buf_
);
}
else
{
}
else
{
ms_idx_
++
;
ms_idx_
++
;
...
@@ -1750,6 +1749,14 @@ public:
...
@@ -1750,6 +1749,14 @@ public:
fmt
::
println
(
"turn: {}, phase: {}, tplayer: {}"
,
turn_count_
,
phase_to_string
(
current_phase_
),
tp_
);
fmt
::
println
(
"turn: {}, phase: {}, tplayer: {}"
,
turn_count_
,
phase_to_string
(
current_phase_
),
tp_
);
}
}
void
show_buffer
()
const
{
fmt
::
println
(
"msg: {}, dp: {}, dl: {}"
,
msg_to_string
(
msg_
),
dp_
,
dl_
);
for
(
int
i
=
0
;
i
<
dl_
;
++
i
)
{
fmt
::
print
(
"{:02x} "
,
data_
[
i
]);
}
fmt
::
print
(
"
\n
"
);
}
void
show_deck
(
PlayerId
player
)
const
{
void
show_deck
(
PlayerId
player
)
const
{
fmt
::
print
(
"Player {}'s deck:
\n
"
,
player
);
fmt
::
print
(
"Player {}'s deck:
\n
"
,
player
);
show_deck
(
player
==
0
?
main_deck0_
:
main_deck1_
,
"Main"
);
show_deck
(
player
==
0
?
main_deck0_
:
main_deck1_
,
"Main"
);
...
@@ -1997,11 +2004,16 @@ private:
...
@@ -1997,11 +2004,16 @@ private:
if
(
it
==
spec2index
.
end
())
{
if
(
it
==
spec2index
.
end
())
{
// TODO: find the root cause
// TODO: find the root cause
// print spec2index
// print spec2index
fmt
::
println
(
"Spec2index:"
);
show_deck
(
0
);
show_deck
(
1
);
show_buffer
();
show_turn
();
fmt
::
println
(
"MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}"
,
ms_idx_
,
ms_mode_
,
ms_min_
,
ms_max_
,
ms_must_
,
ms_specs_
,
ms_combs_
);
fmt
::
println
(
"Spec: {}, Spec2index:"
,
spec
);
for
(
auto
&
[
k
,
v
]
:
spec2index
)
{
for
(
auto
&
[
k
,
v
]
:
spec2index
)
{
fmt
::
println
(
"{}: {}"
,
k
,
v
);
fmt
::
println
(
"{}: {}"
,
k
,
v
);
}
}
//
throw std::runtime_error("Spec not found: " + spec);
throw
std
::
runtime_error
(
"Spec not found: "
+
spec
);
idx
=
1
;
idx
=
1
;
}
else
{
}
else
{
idx
=
it
->
second
;
idx
=
it
->
second
;
...
@@ -4533,11 +4545,7 @@ private:
...
@@ -4533,11 +4545,7 @@ private:
}
else
{
}
else
{
show_deck
(
0
);
show_deck
(
0
);
show_deck
(
1
);
show_deck
(
1
);
// print byte by byte
show_buffer
();
for
(
int
i
=
0
;
i
<
dp_
;
++
i
)
{
fmt
::
print
(
"{:02x} "
,
data_
[
i
]);
}
fmt
::
print
(
"
\n
"
);
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
fmt
::
format
(
"Unknown message {}, length {}, dp {}"
,
fmt
::
format
(
"Unknown message {}, length {}, dp {}"
,
msg_to_string
(
msg_
),
dl_
,
dp_
));
msg_to_string
(
msg_
),
dl_
,
dp_
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment