Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
e79a43ef
Commit
e79a43ef
authored
Apr 07, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use both history actions (cheat)
parent
892c7364
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
57 deletions
+56
-57
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+14
-8
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+42
-49
No files found.
ygoai/rl/jax/agent2.py
View file @
e79a43ef
...
...
@@ -18,7 +18,7 @@ class ActionEncoder(nn.Module):
channels
:
int
=
128
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
x
):
c
=
self
.
channels
...
...
@@ -26,7 +26,6 @@ class ActionEncoder(nn.Module):
embed
=
partial
(
nn
.
Embed
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
embedding_init
=
default_embed_init
)
x_a_msg
=
embed
(
30
,
c
//
div
)(
x
[:,
:,
0
])
x_a_act
=
embed
(
13
,
c
//
div
)(
x
[:,
:,
1
])
x_a_yesno
=
embed
(
3
,
c
//
div
)(
x
[:,
:,
2
])
...
...
@@ -38,9 +37,9 @@ class ActionEncoder(nn.Module):
x_a_number
=
embed
(
13
,
c
//
div
//
2
)(
x
[:,
:,
8
])
x_a_place
=
embed
(
31
,
c
//
div
//
2
)(
x
[:,
:,
9
])
x_a_attrib
=
embed
(
10
,
c
//
div
//
2
)(
x
[:,
:,
10
])
return
jnp
.
concatenate
([
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
],
axis
=-
1
)
xs
=
[
x_a_msg
,
x_a_act
,
x_a_yesno
,
x_a_phase
,
x_a_cancel
,
x_a_finish
,
x_a_position
,
x_a_option
,
x_a_number
,
x_a_place
,
x_a_attrib
]
return
xs
class
CardEncoder
(
nn
.
Module
):
...
...
@@ -169,7 +168,8 @@ class Encoder(nn.Module):
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
action_encoder
=
ActionEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
action_encoder
=
ActionEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
...
...
@@ -216,7 +216,13 @@ class Encoder(nn.Module):
(
c
,
c
),
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
default_fc_init2
)(
id_embed
(
x_h_id
))
x_h_a_feats
=
action_encoder
(
x_h_actions
[:,
:,
2
:])
x_h_a_feats1
=
action_encoder
(
x_h_actions
[:,
:,
2
:
13
])
x_h_a_player
=
embed
(
2
,
c
//
2
)(
x_h_actions
[:,
:,
13
])
x_h_a_turn
=
embed
(
20
,
c
//
2
)(
x_h_actions
[:,
:,
14
])
x_h_a_feats
=
jnp
.
concatenate
([
*
x_h_a_feats1
,
x_h_a_player
,
x_h_a_turn
],
axis
=-
1
)
f_h_actions
=
layer_norm
()(
x_h_id
)
+
layer_norm
()(
fc_layer
(
c
,
dtype
=
jnp
.
float32
)(
x_h_a_feats
))
f_h_actions
=
PositionalEncoding
()(
f_h_actions
)
...
...
@@ -240,7 +246,7 @@ class Encoder(nn.Module):
f_a_cards
=
f_cards
[
B
[:,
None
],
spec_index
]
f_a_cards
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
f_a_cards
)
x_a_feats
=
action_encoder
(
x_actions
[
...
,
2
:]
)
x_a_feats
=
jnp
.
concatenate
(
action_encoder
(
x_actions
[
...
,
2
:]),
axis
=-
1
)
x_a_feats
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
x_a_feats
)
f_actions
=
jnp
.
concatenate
([
f_a_cards
,
x_a_feats
],
axis
=-
1
)
f_actions
=
fc_layer
(
c
,
dtype
=
self
.
dtype
)(
nn
.
leaky_relu
(
f_actions
,
negative_slope
=
0.1
))
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
e79a43ef
...
...
@@ -1263,7 +1263,7 @@ public:
"obs:actions_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_options"
_
],
n_action_feats
})),
"obs:h_actions_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"n_history_actions"
_
],
n_action_feats
})),
Spec
<
uint8_t
>
({
conf
[
"n_history_actions"
_
],
n_action_feats
+
2
})),
"info:num_options"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
conf
[
"max_options"
_
]
-
1
})),
"info:to_play"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:is_selfplay"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
...
...
@@ -1391,15 +1391,10 @@ protected:
const
int
n_history_actions_
;
// circular buffer for history actions of player 0
TArray
<
uint8_t
>
history_actions_0_
;
int
ha_p_0_
=
0
;
std
::
vector
<
CardId
>
h_card_ids_0_
;
// circular buffer for history actions of player 1
TArray
<
uint8_t
>
history_actions_1_
;
int
ha_p_1_
=
0
;
std
::
vector
<
CardId
>
h_card_ids_1_
;
// circular buffer for history actions
TArray
<
uint8_t
>
history_actions_
;
int
ha_p_
=
0
;
std
::
vector
<
CardId
>
h_card_ids_
;
std
::
unordered_set
<
std
::
string
>
revealed_
;
...
...
@@ -1461,12 +1456,9 @@ public:
int
max_options
=
spec
.
config
[
"max_options"
_
];
int
n_action_feats
=
spec
.
state_spec
[
"obs:actions_"
_
].
shape
[
1
];
h_card_ids_0_
.
resize
(
max_options
);
h_card_ids_1_
.
resize
(
max_options
);
history_actions_0_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
})));
history_actions_1_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
})));
h_card_ids_
.
resize
(
max_options
);
history_actions_
=
TArray
<
uint8_t
>
(
Array
(
ShapeSpec
(
sizeof
(
uint8_t
),
{
n_history_actions_
,
n_action_feats
+
2
})));
}
~
YGOProEnv
()
{
...
...
@@ -1537,10 +1529,8 @@ public:
turn_count_
=
0
;
ms_idx_
=
-
1
;
history_actions_0_
.
Zero
();
history_actions_1_
.
Zero
();
ha_p_0_
=
0
;
ha_p_1_
=
0
;
history_actions_
.
Zero
();
ha_p_
=
0
;
clock_t
_start
=
clock
();
...
...
@@ -1803,23 +1793,22 @@ public:
}
void
update_h_card_ids
(
PlayerId
player
,
int
idx
)
{
auto
&
h_card_ids
=
player
==
0
?
h_card_ids_0_
:
h_card_ids_1_
;
h_card_ids
[
idx
]
=
parse_card_id
(
options_
[
idx
],
player
);
h_card_ids_
[
idx
]
=
parse_card_id
(
options_
[
idx
],
player
);
}
void
update_history_actions
(
PlayerId
player
,
int
idx
)
{
auto
&
history_actions
=
player
==
0
?
history_actions_0_
:
history_actions_1_
;
auto
&
ha_p
=
player
==
0
?
ha_p_0_
:
ha_p_1_
;
const
auto
&
h_card_ids
=
player
==
0
?
h_card_ids_0_
:
h_card_ids_1_
;
ha_p
--
;
if
(
ha_p
<
0
)
{
ha_p
=
n_history_actions_
-
1
;
if
((
msg_
==
MSG_SELECT_CHAIN
)
&
(
options_
[
idx
][
0
]
==
'c'
))
{
return
;
}
history_actions
[
ha_p
].
Zero
();
_set_obs_action
(
history_actions
,
ha_p
,
msg_
,
options_
[
idx
],
{},
h_card_ids
[
idx
]);
ha_p_
--
;
if
(
ha_p_
<
0
)
{
ha_p_
=
n_history_actions_
-
1
;
}
history_actions_
[
ha_p_
].
Zero
();
_set_obs_action
(
history_actions_
,
ha_p_
,
msg_
,
options_
[
idx
],
{},
h_card_ids_
[
idx
]);
history_actions_
[
ha_p_
](
13
)
=
static_cast
<
uint8_t
>
(
player
);
history_actions_
[
ha_p_
](
14
)
=
static_cast
<
uint8_t
>
(
turn_count_
);
}
void
show_deck
(
const
std
::
vector
<
CardCode
>
&
deck
,
const
std
::
string
&
prefix
)
const
{
...
...
@@ -1849,7 +1838,7 @@ public:
}
void
show_history_actions
(
PlayerId
player
)
const
{
const
auto
&
ha
=
player
==
0
?
history_actions_0_
:
history_actions_1
_
;
const
auto
&
ha
=
history_actions
_
;
// print card ids of history actions
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
fmt
::
print
(
"history {}
\n
"
,
i
);
...
...
@@ -2064,7 +2053,7 @@ private:
feat
(
2
)
=
op_lp_1
;
feat
(
3
)
=
op_lp_2
;
feat
(
4
)
=
std
::
min
(
turn_count_
,
8
);
feat
(
4
)
=
std
::
min
(
turn_count_
,
16
);
feat
(
5
)
=
phase2id
.
at
(
current_phase_
);
feat
(
6
)
=
(
me
==
0
)
?
1
:
0
;
feat
(
7
)
=
(
me
==
tp_
)
?
1
:
0
;
...
...
@@ -2407,34 +2396,38 @@ private:
n_options
=
options_
.
size
();
state
[
"info:num_options"
_
]
=
n_options
;
// update h_card_ids from state
auto
&
h_card_ids
=
to_play_
==
0
?
h_card_ids_0_
:
h_card_ids_1_
;
// update_h_card_ids from state
for
(
int
i
=
0
;
i
<
n_options
;
++
i
)
{
uint8_t
spec_index1
=
state
[
"obs:actions_"
_
](
i
,
0
);
uint8_t
spec_index2
=
state
[
"obs:actions_"
_
](
i
,
1
);
uint16_t
spec_index
=
(
static_cast
<
uint16_t
>
(
spec_index1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
spec_index2
);
if
(
spec_index
==
0
)
{
h_card_ids
[
i
]
=
0
;
h_card_ids
_
[
i
]
=
0
;
}
else
{
uint8_t
card_id1
=
state
[
"obs:cards_"
_
](
spec_index
-
1
,
0
);
uint8_t
card_id2
=
state
[
"obs:cards_"
_
](
spec_index
-
1
,
1
);
h_card_ids
[
i
]
=
(
static_cast
<
uint16_t
>
(
card_id1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
card_id2
);
h_card_ids
_
[
i
]
=
(
static_cast
<
uint16_t
>
(
card_id1
)
<<
8
)
+
static_cast
<
uint16_t
>
(
card_id2
);
}
}
// write history actions
const
auto
&
ha_p
=
to_play_
==
0
?
ha_p_0_
:
ha_p_1_
;
const
auto
&
history_actions
=
to_play_
==
0
?
history_actions_0_
:
history_actions_1_
;
int
n1
=
n_history_actions_
-
ha_p
;
int
n_action_feats
=
state
[
"obs:actions_"
_
].
Shape
()[
1
];
int
offset
=
n_history_actions_
-
ha_p_
;
int
n_h_action_feats
=
history_actions_
.
Shape
()[
1
];
state
[
"obs:h_actions_"
_
].
Assign
((
uint8_t
*
)
history_actions
[
ha_p
].
Data
(),
n_action_feats
*
n1
);
state
[
"obs:h_actions_"
_
][
n1
].
Assign
((
uint8_t
*
)
history_actions
.
Data
(),
n_action_feats
*
ha_p
);
state
[
"obs:h_actions_"
_
].
Assign
(
(
uint8_t
*
)
history_actions_
[
ha_p_
].
Data
(),
n_h_action_feats
*
offset
);
state
[
"obs:h_actions_"
_
][
offset
].
Assign
(
(
uint8_t
*
)
history_actions_
.
Data
(),
n_h_action_feats
*
ha_p_
);
for
(
int
i
=
0
;
i
<
n_history_actions_
;
++
i
)
{
if
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
2
))
==
0
)
{
break
;
}
state
[
"obs:h_actions_"
_
](
i
,
13
)
=
static_cast
<
uint8_t
>
(
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
13
))
==
to_play_
);
int
turn_diff
=
std
::
min
(
16
,
turn_count_
-
uint8_t
(
state
[
"obs:h_actions_"
_
](
i
,
14
)));
state
[
"obs:h_actions_"
_
](
i
,
14
)
=
static_cast
<
uint8_t
>
(
turn_diff
);
}
}
void
show_decision
(
int
idx
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment