Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
8293ff41
Commit
8293ff41
authored
Feb 23, 2024
by
biluo.shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support replay
parent
1925293b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
17 deletions
+51
-17
scripts/eval.py
scripts/eval.py
+7
-0
scripts/ppo.py
scripts/ppo.py
+5
-3
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+39
-14
No files found.
scripts/eval.py
View file @
8293ff41
...
...
@@ -50,6 +50,8 @@ class Args:
"""whether to play the game"""
selfplay
:
bool
=
False
"""whether to use selfplay"""
record
:
bool
=
False
"""whether to record the game as YGOPro replays"""
num_episodes
:
int
=
1024
"""the number of episodes to run"""
...
...
@@ -87,6 +89,10 @@ if __name__ == "__main__":
if
args
.
play
:
args
.
num_envs
=
1
args
.
verbose
=
True
if
args
.
record
:
assert
args
.
num_envs
==
1
,
"Recording only works with a single environment"
assert
args
.
verbose
,
"Recording only works with verbose mode"
args
.
env_threads
=
min
(
args
.
env_threads
or
args
.
num_envs
,
args
.
num_envs
)
args
.
torch_threads
=
args
.
torch_threads
or
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"4"
))
...
...
@@ -125,6 +131,7 @@ if __name__ == "__main__":
n_history_actions
=
args
.
n_history_actions
,
play_mode
=
'human'
if
args
.
play
else
(
'self'
if
args
.
selfplay
else
(
'bot'
if
args
.
bot_type
==
"greedy"
else
"random"
)),
verbose
=
args
.
verbose
,
record
=
args
.
record
,
)
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
...
...
scripts/ppo.py
View file @
8293ff41
...
...
@@ -112,8 +112,10 @@ class Args:
"""tensorboard log directory"""
ckpt_dir
:
str
=
"./checkpoints"
"""checkpoint directory"""
save_interval
:
int
=
100
save_interval
:
int
=
100
0
"""the number of iterations to save the model"""
log_p
:
float
=
0.1
"""the probability of logging"""
port
:
int
=
12355
"""the port to use for distributed training"""
...
...
@@ -339,7 +341,7 @@ def run(local_rank, world_size):
continue
for
idx
,
d
in
enumerate
(
next_done_
):
if
d
:
if
d
and
random
.
random
()
<
args
.
log_p
:
episode_length
=
info
[
'l'
][
idx
]
episode_reward
=
info
[
'r'
][
idx
]
writer
.
add_scalar
(
"charts/episodic_return"
,
info
[
"r"
][
idx
],
global_step
)
...
...
@@ -420,7 +422,7 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if
local_rank
==
0
:
if
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"
{iteration}
.pth"
))
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"
agent
.pth"
))
writer
.
add_scalar
(
"charts/learning_rate"
,
optimizer
.
param_groups
[
0
][
"lr"
],
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
.
item
(),
global_step
)
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
8293ff41
...
...
@@ -1359,7 +1359,7 @@ protected:
bool
record_
=
false
;
// uint8_t *replay_data_;
// uint8_t *rdata_;
FILE* fp_;
FILE
*
fp_
=
nullptr
;
bool
is_recording
=
false
;
public:
...
...
@@ -1454,7 +1454,7 @@ public:
ha_p_0_
=
0
;
ha_p_1_
=
0
;
unsigned long
duel_seed = dist_int_(gen_);
auto
duel_seed
=
dist_int_
(
gen_
);
std
::
unique_lock
<
std
::
shared_timed_mutex
>
ulock
(
duel_mtx
);
pduel_
=
OCG_CreateDuel
(
duel_seed
);
...
...
@@ -1580,12 +1580,12 @@ public:
if
(
done_
)
{
float
base_reward
=
1.0
;
int
win_turn
=
turn_count_
-
winner_
;
if (win_turn <=
5
) {
base_reward =
2
.0;
if
(
win_turn
<=
1
)
{
base_reward
=
4
.0
;
}
else
if
(
win_turn
<=
3
)
{
base_reward
=
3.0
;
} else if (win_turn <=
1
) {
base_reward =
4
.0;
}
else
if
(
win_turn
<=
5
)
{
base_reward
=
2
.0
;
}
if
(
play_mode_
==
kSelfPlay
)
{
// to_play_ is the previous player
...
...
@@ -1599,6 +1599,14 @@ public:
}
else
if
(
win_reason_
==
0x02
)
{
reason
=
-
1
;
}
if
(
record_
)
{
if
(
!
is_recording
||
fp_
==
nullptr
)
{
throw
std
::
runtime_error
(
"Recording is not started"
);
}
fclose
(
fp_
);
is_recording
=
false
;
}
}
WriteState
(
reward
,
win_reason_
);
...
...
@@ -1942,11 +1950,12 @@ private:
void
str_to_uint16
(
const
char
*
src
,
uint16_t
*
dest
)
{
for (int i = 0; i < strlen(src); i +=
2
) {
dest[i / 2] = src[i] | (src[i + 1] << 8)
;
for
(
int
i
=
0
;
i
<
strlen
(
src
);
i
+=
1
)
{
dest
[
i
]
=
src
[
i
]
;
}
// Add null terminator
dest[
(strlen(src) + 1) / 2
] = '\0';
dest
[
strlen
(
src
)
+
1
]
=
'\0'
;
}
void
ReplayWriteInt8
(
int8_t
value
)
{
...
...
@@ -1958,7 +1967,7 @@ private:
}
// ygopro-core API
intptr_t OCG_CreateDuel(uint
_fast
32_t seed) {
intptr_t
OCG_CreateDuel
(
uint32_t
seed
)
{
if
(
record_
)
{
ReplayHeader
rh
;
rh
.
id
=
0x31707279
;
...
...
@@ -1969,18 +1978,21 @@ private:
fwrite
(
&
rh
,
sizeof
(
rh
),
1
,
fp_
);
fflush
(
fp_
);
}
return create_duel(seed);
std
::
mt19937
rnd
(
seed
);
return
create_duel
(
rnd
());
}
void
OCG_SetPlayerInfo
(
intptr_t
pduel
,
int32
playerid
,
int32
lp
,
int32
startcount
,
int32
drawcount
)
{
if
(
record_
&&
playerid
==
0
)
{
{
uint16_t
name
[
20
];
memset
(
name
,
0
,
40
);
str_to_uint16
(
"Alice"
,
name
);
fwrite
(
name
,
40
,
1
,
fp_
);
}
{
uint16_t
name
[
20
];
memset
(
name
,
0
,
40
);
str_to_uint16
(
"Bob"
,
name
);
fwrite
(
name
,
40
,
1
,
fp_
);
}
...
...
@@ -2030,7 +2042,7 @@ private:
void
OCG_SetResponsei
(
intptr_t
pduel
,
int32
value
)
{
if
(
record_
)
{
ReplayWriteInt
32
(4);
ReplayWriteInt
8
(
4
);
ReplayWriteInt32
(
value
);
}
set_responsei
(
pduel
,
value
);
...
...
@@ -2038,8 +2050,21 @@ private:
void
OCG_SetResponseb
(
intptr_t
pduel
,
byte
*
buf
)
{
if
(
record_
)
{
ReplayWriteInt8(buf[0]);
fwrite(buf + 1, buf[0], 1, fp_);
switch
(
msg_
)
{
case
MSG_SORT_CARD
:
ReplayWriteInt8
(
1
);
fwrite
(
buf
,
1
,
1
,
fp_
);
break
;
case
MSG_SELECT_PLACE
:
case
MSG_SELECT_DISFIELD
:
ReplayWriteInt8
(
3
);
fwrite
(
buf
,
3
,
1
,
fp_
);
break
;
default:
ReplayWriteInt8
(
buf
[
0
]
+
1
);
fwrite
(
buf
,
buf
[
0
]
+
1
,
1
,
fp_
);
break
;
}
}
set_responseb
(
pduel
,
buf
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment