Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
77492ca0
Commit
77492ca0
authored
Jun 07, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add oppo_info
parent
3dfee5f5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
146 additions
and
39 deletions
+146
-39
scripts/cleanba.py
scripts/cleanba.py
+1
-0
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+108
-34
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+37
-5
No files found.
scripts/cleanba.py
View file @
77492ca0
...
@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
...
@@ -211,6 +211,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
greedy_reward
=
args
.
greedy_reward
if
not
eval
else
True
,
greedy_reward
=
args
.
greedy_reward
if
not
eval
else
True
,
play_mode
=
mode
,
play_mode
=
mode
,
timeout
=
args
.
timeout
,
timeout
=
args
.
timeout
,
oppo_info
=
args
.
m2
.
oppo_info
if
eval
else
args
.
m1
.
oppo_info
,
)
)
envs
.
num_envs
=
num_envs
envs
.
num_envs
=
num_envs
return
envs
return
envs
...
...
ygoai/rl/jax/agent.py
View file @
77492ca0
...
@@ -113,7 +113,7 @@ class CardEncoder(nn.Module):
...
@@ -113,7 +113,7 @@ class CardEncoder(nn.Module):
c_mask
=
x_loc
==
0
c_mask
=
x_loc
==
0
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
c_mask
=
c_mask
.
at
[:,
0
]
.
set
(
False
)
x_owner
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
2
])
x_owner
=
embed
(
3
,
c
//
16
)(
x1
[:,
:,
2
])
x_position
=
embed
(
9
,
c
//
16
)(
x1
[:,
:,
3
])
x_position
=
embed
(
9
,
c
//
16
)(
x1
[:,
:,
3
])
x_overley
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
4
])
x_overley
=
embed
(
2
,
c
//
16
)(
x1
[:,
:,
4
])
x_attribute
=
embed
(
8
,
c
//
16
)(
x1
[:,
:,
5
])
x_attribute
=
embed
(
8
,
c
//
16
)(
x1
[:,
:,
5
])
...
@@ -208,6 +208,7 @@ class Encoder(nn.Module):
...
@@ -208,6 +208,7 @@ class Encoder(nn.Module):
card_mask
:
bool
=
False
card_mask
:
bool
=
False
noam
:
bool
=
False
noam
:
bool
=
False
action_feats
:
bool
=
True
action_feats
:
bool
=
True
oppo_info
:
bool
=
False
version
:
int
=
0
version
:
int
=
0
@
nn
.
compact
@
nn
.
compact
...
@@ -227,28 +228,46 @@ class Encoder(nn.Module):
...
@@ -227,28 +228,46 @@ class Encoder(nn.Module):
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
card_encoder
=
CardEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
)
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
action_encoder
=
ActionEncoderCls
(
action_encoder
=
ActionEncoderCls
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards_g
=
x
[
'g_cards_'
]
if
self
.
oppo_info
else
None
x_cards
=
x
[
'cards_'
]
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
batch_size
=
x_
cards
.
shape
[
0
]
batch_size
=
x_
global
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
valid
=
x_global
[:,
-
1
]
==
0
n_cards
=
x_cards
.
shape
[
-
2
]
if
self
.
oppo_info
:
x_cards
=
jnp
.
concatenate
([
x_cards
,
x_cards_g
],
axis
=-
2
)
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
id_embed
(
x_id
)
x_id
=
id_embed
(
x_id
)
if
self
.
freeze_id
:
if
self
.
freeze_id
:
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
f_cards
,
c_mask
=
card_encoder
(
x_id
,
x_cards
[:,
:,
2
:])
if
self
.
oppo_info
:
f_cards_me
,
f_cards_g
=
jnp
.
split
(
f_cards
,
[
n_cards
],
axis
=-
2
)
else
:
f_cards_me
,
f_cards_g
=
f_cards
,
None
# Cards
# Cards
f_cards
,
c_mask
=
CardEncoder
(
fs_g_card
=
[]
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
)(
x_id
,
x_cards
[:,
:,
2
:])
for
i
,
f_cards
in
enumerate
([
f_cards_g
,
f_cards_me
]):
if
f_cards
is
None
:
fs_g_card
.
append
(
None
)
continue
name
=
'g_card_embed'
if
i
==
0
else
'g_g_card_embed'
g_card_embed
=
self
.
param
(
g_card_embed
=
self
.
param
(
'g_card_embed'
,
name
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
lambda
key
,
shape
,
dtype
:
jax
.
random
.
normal
(
key
,
shape
,
dtype
)
*
0.02
,
(
1
,
c
),
self
.
param_dtype
)
(
1
,
c
),
self
.
param_dtype
)
f_g_card
=
jnp
.
tile
(
g_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
f_g_card
=
jnp
.
tile
(
g_card_embed
,
(
batch_size
,
1
,
1
))
.
astype
(
f_cards
.
dtype
)
...
@@ -265,6 +284,8 @@ class Encoder(nn.Module):
...
@@ -265,6 +284,8 @@ class Encoder(nn.Module):
f_cards
,
src_key_padding_mask
=
c_mask
)
f_cards
,
src_key_padding_mask
=
c_mask
)
f_cards
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_cards
)
f_cards
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_cards
)
f_g_card
=
f_cards
[:,
0
]
f_g_card
=
f_cards
[:,
0
]
fs_g_card
.
append
(
f_g_card
)
f_g_g_card
,
f_g_card
=
fs_g_card
# Global
# Global
x_global
=
GlobalEncoder
(
x_global
=
GlobalEncoder
(
...
@@ -412,7 +433,8 @@ class Encoder(nn.Module):
...
@@ -412,7 +433,8 @@ class Encoder(nn.Module):
else
:
else
:
f_state
=
MLP
((
c
*
2
,
oc
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
f_state
=
MLP
((
c
*
2
,
oc
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
return
f_actions
,
f_state
,
a_mask
,
valid
return
f_actions
,
f_state
,
f_g_g_card
,
a_mask
,
valid
class
Actor
(
nn
.
Module
):
class
Actor
(
nn
.
Module
):
...
@@ -473,7 +495,29 @@ class Critic(nn.Module):
...
@@ -473,7 +495,29 @@ class Critic(nn.Module):
return
x
return
x
def
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
main
):
class
GlobalCritic
(
nn
.
Module
):
channels
:
Sequence
[
int
]
=
(
128
,
128
)
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
rstate1
,
rstate2
,
g_cards
):
f_state
=
jnp
.
concatenate
([
rstate1
[
0
],
rstate1
[
1
],
rstate2
[
0
],
rstate2
[
0
]],
axis
=-
1
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
x
=
mlp
(
self
.
channels
,
last_lin
=
True
)(
f_state
)
c
=
self
.
channels
[
-
1
]
t
=
nn
.
Dense
(
c
*
2
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
g_cards
)
s
,
b
=
jnp
.
split
(
t
,
2
,
axis
=-
1
)
x
=
x
*
s
+
b
x
=
mlp
([
c
],
last_lin
=
False
)(
x
)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
return
x
def
rnn_step_by_main
(
rnn_layer
,
rstate
,
f_state
,
done
,
main
,
return_state
=
False
):
if
main
is
not
None
:
if
main
is
not
None
:
rstate1
,
rstate2
=
rstate
rstate1
,
rstate2
=
rstate
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
...
@@ -484,10 +528,13 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
...
@@ -484,10 +528,13 @@ def rnn_step_by_main(rnn_layer, rstate, f_state, done, main):
rstate
=
rstate1
,
rstate2
rstate
=
rstate1
,
rstate2
if
done
is
not
None
:
if
done
is
not
None
:
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
if
return_state
:
return
rstate
,
(
f_state
,
rstate
)
else
:
return
rstate
,
f_state
return
rstate
,
f_state
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
def
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state
,
done
,
switch_or_main
,
switch
=
True
,
return_state
=
False
):
if
switch
:
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
rstate
,
init_rstate2
=
carry
...
@@ -497,7 +544,7 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
...
@@ -497,7 +544,7 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
return
(
rstate
,
init_rstate2
),
y
return
(
rstate
,
init_rstate2
),
y
else
:
else
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
return
rnn_step_by_main
(
cell
,
carry
,
x
,
done
,
main
)
return
rnn_step_by_main
(
cell
,
carry
,
x
,
done
,
main
,
return_state
)
scan
=
nn
.
scan
(
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
split_rngs
=
{
'params'
:
False
})
...
@@ -531,6 +578,8 @@ class ModelArgs(EncoderArgs):
...
@@ -531,6 +578,8 @@ class ModelArgs(EncoderArgs):
"""the type of RNN to use, None for no RNN"""
"""the type of RNN to use, None for no RNN"""
film
:
bool
=
False
film
:
bool
=
False
"""whether to use FiLM for the actor"""
"""whether to use FiLM for the actor"""
oppo_info
:
bool
=
False
"""whether to use opponent's information"""
rwkv_head_size
:
int
=
32
rwkv_head_size
:
int
=
32
"""the head size for the RWKV"""
"""the head size for the RWKV"""
...
@@ -546,6 +595,7 @@ class RNNAgent(nn.Module):
...
@@ -546,6 +595,7 @@ class RNNAgent(nn.Module):
noam
:
bool
=
False
noam
:
bool
=
False
rwkv_head_size
:
int
=
32
rwkv_head_size
:
int
=
32
action_feats
:
bool
=
True
action_feats
:
bool
=
True
oppo_info
:
bool
=
False
version
:
int
=
0
version
:
int
=
0
switch
:
bool
=
True
switch
:
bool
=
True
...
@@ -557,6 +607,8 @@ class RNNAgent(nn.Module):
...
@@ -557,6 +607,8 @@ class RNNAgent(nn.Module):
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
def
__call__
(
self
,
x
,
rstate
,
done
=
None
,
switch_or_main
=
None
):
batch_size
=
jax
.
tree
.
leaves
(
rstate
)[
0
]
.
shape
[
0
]
c
=
self
.
num_channels
c
=
self
.
num_channels
oc
=
self
.
rnn_channels
if
self
.
rnn_type
==
'rwkv'
else
None
oc
=
self
.
rnn_channels
if
self
.
rnn_type
==
'rwkv'
else
None
encoder
=
Encoder
(
encoder
=
Encoder
(
...
@@ -571,10 +623,11 @@ class RNNAgent(nn.Module):
...
@@ -571,10 +623,11 @@ class RNNAgent(nn.Module):
card_mask
=
self
.
card_mask
,
card_mask
=
self
.
card_mask
,
noam
=
self
.
noam
,
noam
=
self
.
noam
,
action_feats
=
self
.
action_feats
,
action_feats
=
self
.
action_feats
,
oppo_info
=
self
.
oppo_info
,
version
=
self
.
version
,
version
=
self
.
version
,
)
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
f_actions
,
f_state
,
f_g
,
mask
,
valid
=
encoder
(
x
)
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
rnn_layer
=
nn
.
OptimizedLSTMCell
(
rnn_layer
=
nn
.
OptimizedLSTMCell
(
...
@@ -594,7 +647,6 @@ class RNNAgent(nn.Module):
...
@@ -594,7 +647,6 @@ class RNNAgent(nn.Module):
elif
self
.
rnn_type
==
'none'
:
elif
self
.
rnn_type
==
'none'
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
else
:
else
:
batch_size
=
jax
.
tree
.
leaves
(
rstate
)[
0
]
.
shape
[
0
]
num_steps
=
f_state
.
shape
[
0
]
//
batch_size
num_steps
=
f_state
.
shape
[
0
]
//
batch_size
multi_step
=
num_steps
>
1
multi_step
=
num_steps
>
1
...
@@ -607,7 +659,11 @@ class RNNAgent(nn.Module):
...
@@ -607,7 +659,11 @@ class RNNAgent(nn.Module):
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
rnn_forward_2p
(
rstate
,
f_state_r
=
rnn_forward_2p
(
rnn_layer
,
rstate
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
rnn_layer
,
rstate
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
,
return_state
=
self
.
oppo_info
)
if
self
.
oppo_info
:
f_state_r
,
all_rstate
=
f_state_r
all_rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
-
1
,
x
.
shape
[
-
1
])),
all_rstate
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
else
:
rstate
,
f_state_r
=
rnn_step_by_main
(
rstate
,
f_state_r
=
rnn_step_by_main
(
...
@@ -619,11 +675,29 @@ class RNNAgent(nn.Module):
...
@@ -619,11 +675,29 @@ class RNNAgent(nn.Module):
else
:
else
:
actor
=
Actor
(
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
if
self
.
oppo_info
:
critic
=
GlobalCritic
(
channels
=
[
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
if
not
multi_step
:
if
isinstance
(
rstate
[
0
],
tuple
):
rstate1_t
,
rstate2_t
=
rstate
else
:
rstate1_t
=
rstate2_t
=
rstate
else
:
main
=
switch_or_main
.
reshape
(
-
1
)[:,
None
]
rstate1
,
rstate2
=
all_rstate
rstate1_t
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x1
,
x2
),
rstate1
,
rstate2
)
rstate2_t
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x2
,
x1
),
rstate1
,
rstate2
)
value
=
critic
(
rstate1_t
,
rstate2_t
,
f_g
)
else
:
critic
=
Critic
(
critic
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
value
=
critic
(
f_state_r
)
if
self
.
int_head
:
if
self
.
int_head
:
critic_int
=
Critic
(
critic_int
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
...
@@ -696,7 +770,7 @@ class RNDModel(nn.Module):
...
@@ -696,7 +770,7 @@ class RNDModel(nn.Module):
version
=
self
.
version
,
version
=
self
.
version
,
)
)
f_
actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
f_
state
=
encoder
(
x
)[
1
]
c
=
f_state
.
shape
[
-
1
]
c
=
f_state
.
shape
[
-
1
]
if
self
.
is_predictor
:
if
self
.
is_predictor
:
predictor
=
MLP
([
oc
,
oc
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
predictor
=
MLP
([
oc
,
oc
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
77492ca0
...
@@ -1527,7 +1527,8 @@ public:
...
@@ -1527,7 +1527,8 @@ public:
"verbose"
_
.
Bind
(
false
),
"max_options"
_
.
Bind
(
16
),
"verbose"
_
.
Bind
(
false
),
"max_options"
_
.
Bind
(
16
),
"max_cards"
_
.
Bind
(
80
),
"n_history_actions"
_
.
Bind
(
16
),
"max_cards"
_
.
Bind
(
80
),
"n_history_actions"
_
.
Bind
(
16
),
"record"
_
.
Bind
(
false
),
"async_reset"
_
.
Bind
(
false
),
"record"
_
.
Bind
(
false
),
"async_reset"
_
.
Bind
(
false
),
"greedy_reward"
_
.
Bind
(
true
),
"timeout"
_
.
Bind
(
600
));
"greedy_reward"
_
.
Bind
(
true
),
"timeout"
_
.
Bind
(
600
),
"oppo_info"
_
.
Bind
(
false
));
}
}
template
<
typename
Config
>
template
<
typename
Config
>
static
decltype
(
auto
)
StateSpec
(
const
Config
&
conf
)
{
static
decltype
(
auto
)
StateSpec
(
const
Config
&
conf
)
{
...
@@ -1539,6 +1540,7 @@ public:
...
@@ -1539,6 +1540,7 @@ public:
Spec
<
uint8_t
>
({
conf
[
"max_options"
_
],
n_action_feats
})),
Spec
<
uint8_t
>
({
conf
[
"max_options"
_
],
n_action_feats
})),
"obs:h_actions_"
_
.
Bind
(
"obs:h_actions_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"n_history_actions"
_
],
n_action_feats
+
2
})),
Spec
<
uint8_t
>
({
conf
[
"n_history_actions"
_
],
n_action_feats
+
2
})),
"obs:g_cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
41
})),
"info:num_options"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
conf
[
"max_options"
_
]
-
1
})),
"info:num_options"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
conf
[
"max_options"
_
]
-
1
})),
"info:to_play"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:to_play"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:is_selfplay"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:is_selfplay"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
...
@@ -2259,8 +2261,13 @@ public:
...
@@ -2259,8 +2261,13 @@ public:
}
}
if
(
play_mode_
==
kSelfPlay
)
{
if
(
play_mode_
==
kSelfPlay
)
{
// if (spec_.config["oppo_info"_]) {
if
(
false
)
{
reward
=
winner_
==
0
?
base_reward
:
-
base_reward
;
}
else
{
// to_play_ is the previous player
// to_play_ is the previous player
reward
=
winner_
==
player
?
base_reward
:
-
base_reward
;
reward
=
winner_
==
player
?
base_reward
:
-
base_reward
;
}
}
else
{
}
else
{
reward
=
winner_
==
ai_player_
?
base_reward
:
-
base_reward
;
reward
=
winner_
==
ai_player_
?
base_reward
:
-
base_reward
;
}
}
...
@@ -2331,6 +2338,9 @@ public:
...
@@ -2331,6 +2338,9 @@ public:
}
}
auto
[
spec_infos
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
auto
[
spec_infos
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
if
(
spec_
.
config
[
"oppo_info"
_
])
{
_set_obs_g_cards
(
state
[
"obs:g_cards_"
_
]);
}
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
...
@@ -2438,8 +2448,30 @@ private:
...
@@ -2438,8 +2448,30 @@ private:
return
{
spec_infos
,
loc_n_cards
};
return
{
spec_infos
,
loc_n_cards
};
}
}
void
_set_obs_g_cards
(
TArray
<
uint8_t
>
&
f_cards
)
{
int
offset
=
0
;
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
std
::
vector
<
uint8_t
>
configs
=
{
LOCATION_DECK
,
LOCATION_HAND
,
LOCATION_MZONE
,
LOCATION_SZONE
,
LOCATION_GRAVE
,
LOCATION_REMOVED
,
LOCATION_EXTRA
,
};
for
(
auto
location
:
configs
)
{
std
::
vector
<
Card
>
cards
=
get_cards_in_location
(
pi
,
location
);
int
n_cards
=
cards
.
size
();
for
(
int
i
=
0
;
i
<
n_cards
;
++
i
)
{
const
auto
&
c
=
cards
[
i
];
CardId
card_id
=
c_get_card_id
(
c
.
code_
);
_set_obs_card_
(
f_cards
,
offset
,
c
,
false
,
card_id
,
false
);
offset
++
;
}
}
}
}
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
bool
hide
,
CardId
card_id
=
0
)
{
bool
hide
,
CardId
card_id
=
0
,
bool
global
=
false
)
{
// check offset exceeds max_cards
// check offset exceeds max_cards
uint8_t
location
=
c
.
location_
;
uint8_t
location
=
c
.
location_
;
bool
overlay
=
location
&
LOCATION_OVERLAY
;
bool
overlay
=
location
&
LOCATION_OVERLAY
;
...
@@ -2462,7 +2494,7 @@ private:
...
@@ -2462,7 +2494,7 @@ private:
seq
=
c
.
sequence_
+
1
;
seq
=
c
.
sequence_
+
1
;
}
}
f_cards
(
offset
,
3
)
=
seq
;
f_cards
(
offset
,
3
)
=
seq
;
f_cards
(
offset
,
4
)
=
(
c
.
controler_
!=
to_play_
)
?
1
:
0
;
f_cards
(
offset
,
4
)
=
global
?
c
.
controler_
:
((
c
.
controler_
!=
to_play_
)
?
1
:
0
)
;
if
(
overlay
)
{
if
(
overlay
)
{
f_cards
(
offset
,
5
)
=
position_to_id
(
POS_FACEUP
);
f_cards
(
offset
,
5
)
=
position_to_id
(
POS_FACEUP
);
f_cards
(
offset
,
6
)
=
1
;
f_cards
(
offset
,
6
)
=
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment