Biluo Shen / ygo-agent / Commits / e5e5402a

Commit e5e5402a
authored Mar 06, 2024 by biluo.shen
Add more global features
parent 4b934828
Showing 7 changed files with 311 additions and 91 deletions
docs/feature_engineering.md        +8    -6
scripts/battle.py                  +229  -0
scripts/eval.py                    +9    -13
scripts/ppo_sp.py                  +2    -1
scripts/ppo_sp2.py                 +2    -1
ygoai/rl/agent.py                  +19   -52
ygoenv/ygoenv/ygopro/ygopro.h      +42   -18
docs/feature_engineering.md
@@ -19,19 +19,21 @@
 - lp: 2, max 65535 to 2 bytes
 - oppo_lp: 2, max 65535 to 2 bytes
 - n_my_decks: 1, int
-- n_my_extras:
 - n_my_hands:
-- n_my_graves:
-- n_my_removes:
 - n_my_monsters:
 - n_my_spell_traps:
+- n_my_graves:
+- n_my_removes:
+- n_my_extras:
 - n_op_decks:
-- n_op_extras:
 - n_op_hands:
-- n_op_graves:
-- n_op_removes:
 - n_op_monsters:
 - n_op_spell_traps:
+- n_op_graves:
+- n_op_removes:
+- n_op_extras:
+- n_my_hands: (another embed, to enhance)
+- n_op_hands: (another embed, to enhance)
 - turn: 1, int, trunc to 8
 - phase: 1, int, one-hot (10)
 - is_first: 1, int, 0: False, 1: True
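Read together with the ygopro.h and agent.py changes below, the updated obs:global_ vector appears to be 23 uint8 entries: the original eight (lp and oppo_lp as two bytes each, turn, phase, is_first, is_my_turn), then the fourteen per-zone counts listed above, with the last slot used as a terminal-state marker. A minimal decoding sketch under that assumed layout (the helper and the byte order are illustrative, not part of the repo):

import numpy as np

ZONES = ["decks", "hands", "monsters", "spell_traps", "graves", "removes", "extras"]

def decode_global(feat: np.ndarray) -> dict:
    # Hypothetical decoder for one 23-byte obs:global_ vector (assumed layout).
    assert feat.shape == (23,) and feat.dtype == np.uint8
    out = {
        "lp": int(feat[0]) * 256 + int(feat[1]),        # byte order assumed
        "oppo_lp": int(feat[2]) * 256 + int(feat[3]),
        "turn": int(feat[4]),
        "phase": int(feat[5]),
        "is_first": bool(feat[6]),
        "is_my_turn": bool(feat[7]),
    }
    counts = feat[8:22]                                  # 14 zone counts: mine, then opponent's
    for i, zone in enumerate(ZONES):
        out[f"n_my_{zone}"] = int(counts[i])
        out[f"n_op_{zone}"] = int(counts[7 + i])
    out["terminal_marker"] = int(feat[22])               # set to 1 when n_options == 0
    return out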
scripts/battle.py
0 → 100644

import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass

import ygoenv
import numpy as np
import optree
import tyro

from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs


@dataclass
class Args:
    seed: int = 1
    """the random seed"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""

    env_id: str = "YGOPro-v0"
    """the id of the environment"""
    deck: str = "../assets/deck/OldSchool.ydk"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    lang: str = "english"
    """the language to use"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 16
    """the number of history actions to use"""
    num_embeddings: Optional[int] = None
    """the number of embeddings of the agent"""
    record: bool = False
    """whether to record the game as YGOPro replays"""

    num_episodes: int = 1024
    """the number of episodes to run"""
    num_envs: int = 64
    """the number of parallel game environments"""
    verbose: bool = False
    """whether to print debug information"""

    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    checkpoint1: Optional[str] = "checkpoints/agent.pt"
    """the checkpoint to load for the first agent"""
    checkpoint2: Optional[str] = "checkpoints/agent.pt"
    """the checkpoint to load for the second agent"""

    compile: bool = True
    """if toggled, the model will be compiled"""
    optimize: bool = False
    """if toggled, the model will be optimized"""
    torch_threads: Optional[int] = None
    """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
    env_threads: Optional[int] = 16
    """the number of threads to use for envpool, defaults to `num_envs`"""


def predict_step(agent, obs):
    with torch.no_grad():
        logits, values, _valid = agent(obs)
        probs = torch.softmax(logits, dim=-1)
    return probs


if __name__ == "__main__":
    args = tyro.cli(Args)

    if args.record:
        assert args.num_envs == 1, "Recording only works with a single environment"
        assert args.verbose, "Recording only works with verbose mode"

    args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
    args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))

    deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)

    import torch
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    torch.set_num_threads(args.torch_threads)
    torch.set_float32_matmul_precision('high')

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    num_envs = args.num_envs
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=args.env_threads,
        seed=seed,
        deck1=args.deck1,
        deck2=args.deck2,
        player=-1,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode='self',
        verbose=args.verbose,
        record=args.record,
    )
    envs.num_envs = num_envs
    envs = RecordEpisodeStatistics(envs)

    embedding_shape = args.num_embeddings
    if embedding_shape is None:
        with open(args.code_list_file, "r") as f:
            code_list = f.readlines()
            embedding_shape = len(code_list)
    L = args.num_layers
    agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
    agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)

    for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
        state_dict = torch.load(ckpt, map_location=device)
        if not args.compile:
            prefix = "_orig_mod."
            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
        print(agent.load_state_dict(state_dict))

    if args.compile:
        predict_step = torch.compile(predict_step, mode='reduce-overhead')
    else:
        if args.optimize:
            obs = create_obs(envs.observation_space, (num_envs,), device=device)
            def optimize_for_inference(agent):
                with torch.no_grad():
                    traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
                    return torch.jit.optimize_for_inference(traced_model)
            agent1 = optimize_for_inference(agent1)
            agent2 = optimize_for_inference(agent2)

    obs, infos = envs.reset()
    next_to_play_ = infos['to_play']

    episode_rewards = []
    episode_lengths = []
    win_rates = []
    win_reasons = []

    step = 0
    start = time.time()
    start_step = step

    player1_ = np.concatenate([
        np.zeros(num_envs // 2, dtype=np.int64),
        np.ones(num_envs // 2, dtype=np.int64)
    ])
    player1 = torch.from_numpy(player1_).to(device=device)

    model_time = env_time = 0
    while True:
        if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
            start = time.time()
            start_step = step
            model_time = env_time = 0

        _start = time.time()
        next_to_play = torch.from_numpy(next_to_play_).to(device=device)
        obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
        probs1 = predict_step(agent1, obs).clone()
        probs2 = predict_step(agent2, obs).clone()
        probs = torch.where((next_to_play == player1)[:, None], probs1, probs2)
        probs = probs.cpu().numpy()
        actions = probs.argmax(axis=1)
        model_time += time.time() - _start

        to_play = next_to_play_

        _start = time.time()
        obs, rewards, dones, infos = envs.step(actions)
        next_to_play_ = infos['to_play']
        env_time += time.time() - _start

        step += 1

        for idx, d in enumerate(dones):
            if d:
                win_reason = infos['win_reason'][idx]
                pl = 1 if to_play[idx] == player1_[idx] else -1
                episode_length = infos['l'][idx]
                episode_reward = infos['r'][idx] * pl
                win = 1 if episode_reward > 0 else 0

                episode_lengths.append(episode_length)
                episode_rewards.append(episode_reward)
                win_rates.append(win)
                win_reasons.append(1 if win_reason == 1 else 0)
                sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")

        if len(episode_lengths) >= args.num_episodes:
            break

    print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")

    total_time = time.time() - start
    total_steps = (step - start_step) * num_envs
    print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
    print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
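Since the script is driven by tyro, every field of Args above becomes a command-line flag. A head-to-head evaluation between two checkpoints might be launched roughly as follows (flag spelling depends on tyro's naming convention, and the paths are illustrative):

python scripts/battle.py --checkpoint1 checkpoints/agent.pt --checkpoint2 checkpoints/agent_old.pt --num_envs 64 --num_episodes 1024

Note that half of the parallel environments seat agent1 as the first player and half as the second (the player1_ split above), so the reported win rate is not biased by turn order.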
scripts/eval.py
@@ -150,24 +150,20 @@ if __name__ == "__main__":
     # agent = agent.eval()
     if args.checkpoint:
         state_dict = torch.load(args.checkpoint, map_location=device)
-    else:
-        state_dict = None
+        if not args.compile:
+            prefix = "_orig_mod."
+            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
+        print(agent.load_state_dict(state_dict))
 
     if args.compile:
-        if state_dict:
-            print(agent.load_state_dict(state_dict))
         agent = torch.compile(agent, mode='reduce-overhead')
-    else:
-        prefix = "_orig_mod."
-        if state_dict:
-            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-            print(agent.load_state_dict(state_dict))
-        if args.optimize:
-            obs = create_obs(envs.observation_space, (num_envs,), device=device)
-            with torch.no_grad():
-                traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
-                agent = torch.jit.optimize_for_inference(traced_model)
+    elif args.optimize:
+        obs = create_obs(envs.observation_space, (num_envs,), device=device)
+        def optimize_for_inference(agent):
+            with torch.no_grad():
+                traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
+                return torch.jit.optimize_for_inference(traced_model)
+        agent = optimize_for_inference(agent)
 
     obs, infos = envs.reset()
     next_to_play = infos['to_play']
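Both this hunk and the new battle.py handle checkpoints saved from a torch.compile-wrapped model, whose parameter names carry a "_orig_mod." prefix; when the agent is not compiled, that prefix is stripped before load_state_dict. A self-contained sketch of the same mapping (the function name is mine, not the repo's):

import torch

def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    # Remove the prefix torch.compile adds to parameter names so the checkpoint
    # also loads into an un-compiled module (mirrors eval.py / battle.py).
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# Usage (paths illustrative):
# state_dict = torch.load("checkpoints/agent.pt", map_location="cpu")
# agent.load_state_dict(strip_compile_prefix(state_dict))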
scripts/ppo_sp.py
@@ -626,7 +626,8 @@ def run(local_rank, world_size):
             eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
 
             # sync the statistics
-            dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
+            if args.world_size > 1:
+                dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
             if local_rank == 0:
                 eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
                 writer.add_scalar("charts/eval_return", eval_return, global_step)
scripts/ppo_sp2.py
@@ -633,7 +633,8 @@ def main():
             eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
 
             # sync the statistics
-            dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
+            if args.world_size > 1:
+                dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
             eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
             if rank == 0:
                 writer.add_scalar("charts/eval_return", eval_return, global_step)
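Both hunks guard the evaluation-stat sync with args.world_size > 1, since dist.all_reduce requires an initialized process group and would fail in a single-process run. A minimal sketch of the guarded reduction (helper name is mine):

import torch
import torch.distributed as dist

def sync_eval_stats(stats: torch.Tensor, world_size: int) -> torch.Tensor:
    # Average eval stats across ranks only when running distributed;
    # single-process runs skip the collective entirely.
    if world_size > 1:
        dist.all_reduce(stats, op=dist.ReduceOp.AVG)  # in-place average across ranks
    return stats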
ygoai/rl/agent.py
@@ -45,6 +45,9 @@ class Encoder(nn.Module):
         self.bin_points = nn.Parameter(bin_points, requires_grad=False)
         self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
 
+        self.count_embed = nn.Embedding(100, c // 16)
+        self.hand_count_embed = nn.Embedding(100, c // 16)
+
         if embedding_shape is None:
             n_embed, embed_dim = 999, 1024
         elif isinstance(embedding_shape, int):
@@ -88,12 +91,15 @@ class Encoder(nn.Module):
         self.if_first_embed = nn.Embedding(2, c // 8)
         self.is_my_turn_embed = nn.Embedding(2, c // 8)
 
-        self.global_norm_pre = nn.LayerNorm(c, elementwise_affine=affine)
+        self.my_deck_fc_emb = linear(1024, c // 4)
+        self.global_norm_pre = nn.LayerNorm(c * 2, elementwise_affine=affine)
         self.global_net = nn.Sequential(
-            nn.Linear(c, c),
+            nn.Linear(c * 2, c * 2),
             nn.ReLU(),
-            nn.Linear(c, c),
+            nn.Linear(c * 2, c * 2),
         )
+        self.global_proj = nn.Linear(c * 2, c)
         self.global_norm = nn.LayerNorm(c, elementwise_affine=False)
 
         divisor = 8
@@ -235,13 +241,20 @@ class Encoder(nn.Module):
         x_g_lp = self.lp_fc_emb(self.num_transform(x_global_1[:, 0:2]))
         x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))
 
-        x_global_2 = x[:, 4:-1].long()
+        x_global_2 = x[:, 4:8].long()
         x_g_turn = self.turn_embed(x_global_2[:, 0])
         x_g_phase = self.phase_embed(x_global_2[:, 1])
         x_g_if_first = self.if_first_embed(x_global_2[:, 2])
         x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 3])
 
-        x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
+        x_global_3 = x[:, 8:22].long()
+        x_g_cs = self.count_embed(x_global_3).flatten(1)
+        x_g_my_hand_c = self.hand_count_embed(x_global_3[:, 1])
+        x_g_op_hand_c = self.hand_count_embed(x_global_3[:, 8])
+
+        x_global = torch.cat([
+            x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
+            x_g_cs, x_g_my_hand_c, x_g_op_hand_c], dim=-1)
         return x_global
 
     def forward(self, x):
@@ -278,6 +291,7 @@ class Encoder(nn.Module):
         x_global = self.encode_global(x_global)
         x_global = self.global_norm_pre(x_global)
         f_global = x_global + self.global_net(x_global)
+        f_global = self.global_proj(f_global)
         f_global = self.global_norm(f_global)
         f_cards = f_cards + f_global.unsqueeze(1)
@@ -320,53 +334,6 @@ class Encoder(nn.Module):
         f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
         return f_actions, f_state, mask, valid
 
-# class PPOCritic(nn.Module):
-#     def __init__(self, channels):
-#         super(PPOCritic, self).__init__()
-#         c = channels
-#         self.net = nn.Sequential(
-#             nn.Linear(c * 2, c // 2),
-#             nn.ReLU(),
-#             nn.Linear(c // 2, 1),
-#         )
-#
-#     def forward(self, f_state):
-#         return self.net(f_state)
-
-# class PPOActor(nn.Module):
-#     def __init__(self, channels):
-#         super(PPOActor, self).__init__()
-#         c = channels
-#         self.trans = nn.TransformerEncoderLayer(
-#             c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
-#         self.head = nn.Sequential(
-#             nn.Linear(c, c // 4),
-#             nn.ReLU(),
-#             nn.Linear(c // 4, 1),
-#         )
-#
-#     def forward(self, f_actions, mask, action):
-#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
-#         logits = self.head(f_actions)[..., 0]
-#         logits = logits.float()
-#         logits = logits.masked_fill(mask, float("-inf"))
-#         probs = Categorical(logits=logits)
-#         return probs.log_prob(action), probs.entropy()
-
-#     def predict(self, f_actions, mask):
-#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
-#         logits = self.head(f_actions)[..., 0]
-#         logits = logits.float()
-#         logits = logits.masked_fill(mask, float("-inf"))
-#         return logits
-
 class Actor(nn.Module):
     def __init__(self, channels, use_transformer=False):
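The width changes in Encoder follow from the new count embeddings: assuming lp_fc_emb/oppo_lp_fc_emb are c // 4 wide and the turn/phase/if_first/is_my_turn embeddings c // 8 wide (only the last two are visible in this diff), the concatenated global vector grows from c to 2c, which is what LayerNorm(c * 2), the wider global_net, and the new global_proj back to c account for. A quick bookkeeping check under those assumptions:

c = 128  # num_channels default in scripts/battle.py
old_width = 2 * (c // 4) + 4 * (c // 8)                 # lp, oppo_lp, turn, phase, is_first, is_my_turn
new_width = old_width + 14 * (c // 16) + 2 * (c // 16)  # + count_embed over 14 zones + 2 hand_count_embed
assert (old_width, new_width) == (c, 2 * c)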
ygoenv/ygoenv/ygopro/ygopro.h
@@ -17,6 +17,7 @@
 #include <SQLiteCpp/SQLiteCpp.h>
 #include <SQLiteCpp/VariadicBind.h>
 #include <ankerl/unordered_dense.h>
+#include <unordered_map>
 
 #include "ygoenv/core/async_envpool.h"
 #include "ygoenv/core/env.h"
@@ -1236,7 +1237,7 @@ public:
     int n_action_feats = 10 + conf["max_multi_select"_] * 2;
     return MakeDict(
         "obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 40})),
-        "obs:global_"_.Bind(Spec<uint8_t>({9})),
+        "obs:global_"_.Bind(Spec<uint8_t>({23})),
         "obs:actions_"_.Bind(
             Spec<uint8_t>({conf["max_options"_], n_action_feats})),
         "obs:h_actions_"_.Bind(
@@ -1650,17 +1651,32 @@ public:
     float reward = 0;
     int reason = 0;
     if (done_) {
-      float base_reward = 1.0;
-      if (turn_count_ <= 1) {
-        base_reward = 16.0;
-      } else if (turn_count_ <= 3) {
-        base_reward = 8.0;
-      } else if (turn_count_ <= 5) {
-        base_reward = 4.0;
-      } else if (turn_count_ <= 7) {
-        base_reward = 2.0;
-      } else {
-        base_reward = 0.5 + 1.0 / (turn_count_ - 7);
-      }
+      float base_reward;
+      int win_turn = turn_count_ - winner_;
+      if (winner_ == 0) {
+        if (win_turn <= 1) {
+          base_reward = 8.0; // FTK
+        } else if (win_turn <= 3) {
+          base_reward = 4.0;
+        } else if (win_turn <= 5) {
+          base_reward = 2.0;
+        } else {
+          base_reward = 0.5 + 1.0 / (win_turn - 5);
+        }
+      } else {
+        if (turn_count_ <= 1) {
+          base_reward = 8.0;
+        } else if (turn_count_ <= 3) {
+          base_reward = 4.0;
+        } else if (turn_count_ <= 5) {
+          base_reward = 2.0;
+        } else {
+          base_reward = 0.5 + 1.0 / (turn_count_ - 5);
+        }
+      }
 
       if (play_mode_ == kSelfPlay) {
         // to_play_ is the previous player
         reward = winner_ == to_play_ ? base_reward : -base_reward;
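The reward shaping now scales with how quickly the game ends on a shorter schedule: roughly 8 for a first-turn kill, 4 by turn 3, 2 by turn 5, and 0.5 + 1/(turns - 5) beyond that, with win_turn = turn_count_ - winner_ offsetting the count for the second player; under self-play the loser receives the negated value. A Python restatement of the schedule (sketch, thresholds copied from the diff above):

def base_reward_for_win(win_turn: int) -> float:
    # Mirrors the new schedule in ygopro.h; the caller negates it for the loser.
    if win_turn <= 1:
        return 8.0  # FTK
    elif win_turn <= 3:
        return 4.0
    elif win_turn <= 5:
        return 2.0
    return 0.5 + 1.0 / (win_turn - 5)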
@@ -1698,8 +1714,9 @@ public:
 private:
   using SpecIndex = ankerl::unordered_dense::map<std::string, uint16_t>;
 
-  void _set_obs_cards(TArray<uint8_t> &f_cards, SpecIndex &spec2index,
-                      PlayerId to_play) {
+  std::tuple<SpecIndex, std::vector<int>>
+  _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
+    SpecIndex spec2index;
+    std::vector<int> loc_n_cards;
     for (auto pi = 0; pi < 2; pi++) {
       const PlayerId player = (to_play + pi) % 2;
       const bool opponent = pi == 1;
@@ -1718,6 +1735,7 @@ private:
         }
         if (opponent && hidden_for_opponent) {
           auto n_cards = YGO_QueryFieldCount(pduel_, player, location);
+          loc_n_cards.push_back(n_cards);
           for (auto i = 0; i < n_cards; i++) {
             f_cards(offset, 2) = location2id.at(location);
             f_cards(offset, 4) = 1;
@@ -1725,7 +1743,9 @@ private:
           }
         } else {
           std::vector<Card> cards = get_cards_in_location(player, location);
-          for (int i = 0; i < cards.size(); ++i) {
+          int n_cards = cards.size();
+          loc_n_cards.push_back(n_cards);
+          for (int i = 0; i < n_cards; ++i) {
             const auto &c = cards[i];
             auto spec = c.get_spec(opponent);
             bool hide = false;
@@ -1744,6 +1764,7 @@ private:
         }
       }
     }
+    return {spec2index, loc_n_cards};
   }
 
   void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
@@ -1797,7 +1818,7 @@ private:
     }
   }
 
-  void _set_obs_global(TArray<uint8_t> &feat, PlayerId player) {
+  void _set_obs_global(TArray<uint8_t> &feat, PlayerId player, const std::vector<int> &loc_n_cards) {
     uint8_t me = player;
     uint8_t op = 1 - player;
@@ -1813,6 +1834,10 @@ private:
     feat(5) = phase2id.at(current_phase_);
     feat(6) = (me == 0) ? 1 : 0;
     feat(7) = (me == tp_) ? 1 : 0;
+
+    for (int i = 0; i < loc_n_cards.size(); i++) {
+      feat(8 + i) = static_cast<uint8_t>(loc_n_cards[i]);
+    }
   }
 
   void _set_obs_action_spec(TArray<uint8_t> &feat, int i, int j,
@@ -2148,14 +2173,13 @@ private:
     if (n_options == 0) {
       state["info:num_options"_] = 1;
-      state["obs:global_"_][8] = uint8_t(1);
+      state["obs:global_"_][22] = uint8_t(1);
       return;
     }
 
-    SpecIndex spec2index;
-    _set_obs_cards(state["obs:cards_"_], spec2index, to_play_);
-    _set_obs_global(state["obs:global_"_], to_play_);
+    auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
+    _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
 
     // we can't shuffle because idx must be stable in callback
     if (n_options > max_options()) {