Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
ad9c3c34
Commit
ad9c3c34
authored
Feb 20, 2024
by
biluo.shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add turn feature
parent
0c7cfc92
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
14 deletions
+28
-14
docs/features.md
docs/features.md
+1
-1
scripts/eval.py
scripts/eval.py
+2
-1
ygoai/rl/agent.py
ygoai/rl/agent.py
+9
-7
ygoai/rl/buffer.py
ygoai/rl/buffer.py
+10
-0
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+6
-5
No files found.
docs/features.md
View file @
ad9c3c34
...
...
@@ -17,7 +17,7 @@
## Global
-
lp: 2, max 65535 to 2 bytes
-
oppo_lp: 2, max 65535 to 2 bytes
<!-- - turn: 8, int, trunc to 8 -->
-
turn: 1, int, trunc to 8
-
phase: 1, int, one-hot (10)
-
is_first: 1, int, 0: False, 1: True
-
is_my_turn: 1, int, 0: False, 1: True
...
...
scripts/eval.py
View file @
ad9c3c34
...
...
@@ -15,6 +15,7 @@ import tyro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.agent
import
Agent
from
ygoai.rl.buffer
import
create_obs
@
dataclass
...
...
@@ -143,7 +144,7 @@ if __name__ == "__main__":
state_dict
=
{
k
[
len
(
prefix
):]
if
k
.
startswith
(
prefix
)
else
k
:
v
for
k
,
v
in
state_dict
.
items
()}
agent
.
load_state_dict
(
state_dict
)
obs
=
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
),
envs
.
reset
()[
0
]
)
obs
=
create_obs
(
envs
.
observation_space
,
(
num_envs
,),
device
=
device
)
with
torch
.
no_grad
():
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
agent
=
torch
.
jit
.
optimize_for_inference
(
traced_model
)
...
...
ygoai/rl/agent.py
View file @
ad9c3c34
...
...
@@ -78,7 +78,8 @@ class Agent(nn.Module):
self
.
lp_fc_emb
=
linear
(
c_num
,
c
//
4
)
self
.
oppo_lp_fc_emb
=
linear
(
c_num
,
c
//
4
)
self
.
phase_embed
=
nn
.
Embedding
(
10
,
c
//
4
)
self
.
turn_embed
=
nn
.
Embedding
(
20
,
c
//
8
)
self
.
phase_embed
=
nn
.
Embedding
(
10
,
c
//
8
)
self
.
if_first_embed
=
nn
.
Embedding
(
2
,
c
//
8
)
self
.
is_my_turn_embed
=
nn
.
Embedding
(
2
,
c
//
8
)
...
...
@@ -231,11 +232,12 @@ class Agent(nn.Module):
x_g_oppo_lp
=
self
.
oppo_lp_fc_emb
(
self
.
num_transform
(
x_global_1
[:,
2
:
4
]))
x_global_2
=
x
[:,
4
:
-
1
]
.
long
()
x_g_phase
=
self
.
phase_embed
(
x_global_2
[:,
0
])
x_g_if_first
=
self
.
if_first_embed
(
x_global_2
[:,
1
])
x_g_is_my_turn
=
self
.
is_my_turn_embed
(
x_global_2
[:,
2
])
x_g_turn
=
self
.
turn_embed
(
x_global_2
[:,
0
])
x_g_phase
=
self
.
phase_embed
(
x_global_2
[:,
1
])
x_g_if_first
=
self
.
if_first_embed
(
x_global_2
[:,
2
])
x_g_is_my_turn
=
self
.
is_my_turn_embed
(
x_global_2
[:,
3
])
x_global
=
torch
.
cat
([
x_g_lp
,
x_g_oppo_lp
,
x_g_phase
,
x_g_if_first
,
x_g_is_my_turn
],
dim
=-
1
)
x_global
=
torch
.
cat
([
x_g_lp
,
x_g_oppo_lp
,
x_g_
turn
,
x_g_
phase
,
x_g_if_first
,
x_g_is_my_turn
],
dim
=-
1
)
return
x_global
def
forward
(
self
,
x
):
...
...
@@ -308,6 +310,6 @@ class Agent(nn.Module):
f_actions
=
self
.
action_norm
(
f_actions
)
values
=
self
.
value_head
(
f_actions
)[
...
,
0
]
values
=
torch
.
tanh
(
values
)
values
=
torch
.
where
(
mask
,
torch
.
full_like
(
values
,
-
1
.01
),
values
)
#
values = torch.tanh(values)
values
=
torch
.
where
(
mask
,
torch
.
full_like
(
values
,
-
1
0
),
values
)
return
values
,
valid
\ No newline at end of file
ygoai/rl/buffer.py
View file @
ad9c3c34
...
...
@@ -510,6 +510,16 @@ class DMCBuffer:
return
data
def
create_obs
(
observation_space
:
spaces
.
Dict
,
shape
:
Tuple
[
int
,
...
],
device
:
Union
[
th
.
device
,
str
]
=
"cpu"
):
obs_shape
=
get_obs_shape
(
observation_space
)
obs
=
{
key
:
th
.
zeros
(
(
*
shape
,
*
_obs_shape
),
dtype
=
dtype_dict
[
observation_space
[
key
]
.
dtype
.
type
],
device
=
device
)
for
key
,
_obs_shape
in
obs_shape
.
items
()
}
return
obs
class
DMCDictBuffer
:
observation_space
:
spaces
.
Dict
obs_shape
:
Dict
[
str
,
Tuple
[
int
,
...
]]
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
ad9c3c34
...
...
@@ -1188,7 +1188,7 @@ public:
int
n_action_feats
=
9
+
conf
[
"max_multi_select"
_
]
*
2
;
return
MakeDict
(
"obs:cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
39
})),
"obs:global_"
_
.
Bind
(
Spec
<
uint8_t
>
({
8
})),
"obs:global_"
_
.
Bind
(
Spec
<
uint8_t
>
({
9
})),
"obs:actions_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_options"
_
],
n_action_feats
})),
"obs:h_actions_"
_
.
Bind
(
...
...
@@ -1659,9 +1659,10 @@ private:
feat
(
2
)
=
op_lp_1
;
feat
(
3
)
=
op_lp_2
;
feat
(
4
)
=
phase2id
.
at
(
current_phase_
);
feat
(
5
)
=
(
me
==
0
)
?
1
:
0
;
feat
(
6
)
=
(
me
==
tp_
)
?
1
:
0
;
feat
(
4
)
=
std
::
min
(
turn_count_
,
8
);
feat
(
5
)
=
phase2id
.
at
(
current_phase_
);
feat
(
6
)
=
(
me
==
0
)
?
1
:
0
;
feat
(
7
)
=
(
me
==
tp_
)
?
1
:
0
;
}
void
_set_obs_action_spec
(
TArray
<
uint8_t
>
&
feat
,
int
i
,
int
j
,
...
...
@@ -1883,7 +1884,7 @@ private:
if
(
n_options
==
0
)
{
state
[
"info:num_options"
_
]
=
1
;
state
[
"obs:global_"
_
][
7
]
=
uint8_t
(
1
);
state
[
"obs:global_"
_
][
8
]
=
uint8_t
(
1
);
return
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment