Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
72a1fd28
Commit
72a1fd28
authored
Jun 19, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add oppo_info
parent
5cd9807d
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2068 additions
and
53 deletions
+2068
-53
scripts/cleanba.py
scripts/cleanba.py
+5
-5
scripts/cleanba_g.py
scripts/cleanba_g.py
+1213
-0
ygoai/rl/env.py
ygoai/rl/env.py
+27
-2
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+42
-35
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+660
-0
ygoai/rl/jax/utils.py
ygoai/rl/jax/utils.py
+0
-2
ygoai/rl/utils.py
ygoai/rl/utils.py
+1
-4
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+120
-5
No files found.
scripts/cleanba.py
View file @
72a1fd28
...
@@ -23,9 +23,10 @@ from rich.pretty import pprint
...
@@ -23,9 +23,10 @@ from rich.pretty import pprint
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.utils
import
RecordEpisodeStatistics
,
EnvPreprocess
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
,
TrainState
from
ygoai.rl.jax.utils
import
masked_normalize
,
categorical_sample
,
TrainState
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax.switch
import
truncated_gae_sep
as
gae_sep_switch
from
ygoai.rl.jax.switch
import
truncated_gae_sep
as
gae_sep_switch
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
\
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
\
...
@@ -356,6 +357,7 @@ def rollout(
...
@@ -356,6 +357,7 @@ def rollout(
args
.
local_env_threads
,
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
)
)
envs
=
EnvPreprocess
(
envs
,
skip_mask
=
not
args
.
m1
.
oppo_info
)
envs
=
RecordEpisodeStatistics
(
envs
)
envs
=
RecordEpisodeStatistics
(
envs
)
eval_envs
=
make_env
(
eval_envs
=
make_env
(
...
@@ -363,6 +365,7 @@ def rollout(
...
@@ -363,6 +365,7 @@ def rollout(
local_seed
+
100000
,
local_seed
+
100000
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
EnvPreprocess
(
eval_envs
,
skip_mask
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
...
@@ -440,9 +443,6 @@ def rollout(
...
@@ -440,9 +443,6 @@ def rollout(
init_rstates
=
[]
init_rstates
=
[]
# @jax.jit
# def prepare_data(storage: List[Transition]) -> Transition:
# return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
@
jax
.
jit
@
jax
.
jit
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
stack
(
xs
),
*
storage
)
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
stack
(
xs
),
*
storage
)
...
@@ -566,7 +566,7 @@ def rollout(
...
@@ -566,7 +566,7 @@ def rollout(
for
x
in
partitioned_storage
:
for
x
in
partitioned_storage
:
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
x
=
{
x
=
{
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
if
v
is
not
None
else
None
for
k
,
v
in
x
.
items
()
for
k
,
v
in
x
.
items
()
}
}
elif
x
is
not
None
:
elif
x
is
not
None
:
...
...
scripts/cleanba_g.py
0 → 100644
View file @
72a1fd28
This diff is collapsed.
Click to expand it.
ygoai/rl/env.py
View file @
72a1fd28
...
@@ -60,15 +60,40 @@ class RecordEpisodeStatistics(gym.Wrapper):
...
@@ -60,15 +60,40 @@ class RecordEpisodeStatistics(gym.Wrapper):
class
CompatEnv
(
gym
.
Wrapper
):
class
CompatEnv
(
gym
.
Wrapper
):
def
reset
(
self
,
**
kwargs
):
def
reset
(
self
,
**
kwargs
):
observations
,
infos
=
s
uper
()
.
reset
(
**
kwargs
)
observations
,
infos
=
s
elf
.
env
.
reset
(
**
kwargs
)
return
observations
,
infos
return
observations
,
infos
def
step
(
self
,
action
):
def
step
(
self
,
action
):
observations
,
rewards
,
terminated
,
truncated
,
infos
=
s
elf
.
env
.
step
(
action
)
observations
,
rewards
,
terminated
,
truncated
,
infos
=
s
uper
()
.
step
(
action
)
dones
=
np
.
logical_or
(
terminated
,
truncated
)
dones
=
np
.
logical_or
(
terminated
,
truncated
)
return
(
return
(
observations
,
observations
,
rewards
,
rewards
,
dones
,
dones
,
infos
,
infos
,
)
class
EnvPreprocess
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
skip_mask
):
super
()
.
__init__
(
env
)
self
.
skip_mask
=
skip_mask
def
reset
(
self
,
**
kwargs
):
observations
,
infos
=
self
.
env
.
reset
(
**
kwargs
)
if
self
.
skip_mask
:
observations
[
'mask_'
]
=
None
return
observations
,
infos
def
step
(
self
,
action
):
observations
,
rewards
,
terminated
,
truncated
,
infos
=
super
()
.
step
(
action
)
if
self
.
skip_mask
:
observations
[
'mask_'
]
=
None
return
(
observations
,
rewards
,
terminated
,
truncated
,
infos
,
)
)
\ No newline at end of file
ygoai/rl/jax/agent.py
View file @
72a1fd28
...
@@ -85,7 +85,8 @@ class CardEncoder(nn.Module):
...
@@ -85,7 +85,8 @@ class CardEncoder(nn.Module):
version
:
int
=
0
version
:
int
=
0
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
x_id
,
x
):
def
__call__
(
self
,
x_id
,
x
,
mask
):
assert
self
.
version
>
0
c
=
self
.
channels
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
...
@@ -136,18 +137,35 @@ class CardEncoder(nn.Module):
...
@@ -136,18 +137,35 @@ class CardEncoder(nn.Module):
x_f
=
layer_norm
()(
x_f
)
x_f
=
layer_norm
()(
x_f
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
f_cards
=
f_cards
+
f_loc
+
f_seq
f_cards_g
=
None
else
:
else
:
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
f_seq
=
embed
(
76
,
c
//
16
*
2
)(
x_seq
)
f_seq
=
embed
(
76
,
c
//
16
*
2
)(
x_seq
)
x_cards
=
jnp
.
concatenate
([
feats_g
=
[
f_loc
,
f_seq
,
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_id
,
f_loc
,
f_seq
,
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
]
if
mask
is
not
None
:
assert
len
(
feats_g
)
==
mask
.
shape
[
-
1
]
feats
=
[
jnp
.
where
(
mask
[
...
,
i
:
i
+
1
]
==
1
,
f
,
f
[
...
,
-
1
:,
:])
for
i
,
f
in
enumerate
(
feats_g
)
]
else
:
feats
=
feats_g
x_cards
=
jnp
.
concatenate
(
feats
[
1
:],
axis
=-
1
)
x_cards
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards
)
x_cards
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards
)
x_cards
=
x_cards
*
x_id
x_cards
=
x_cards
*
feats
[
0
]
f_cards
=
layer_norm
()(
x_cards
)
f_cards
=
layer_norm
()(
x_cards
)
return
f_cards
,
c_mask
if
self
.
oppo_info
:
x_cards_g
=
jnp
.
concatenate
(
feats_g
[
1
:],
axis
=-
1
)
x_cards_g
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards_g
)
x_cards_g
=
x_cards_g
*
feats_g
[
0
]
f_cards_g
=
layer_norm
()(
x_cards_g
)
else
:
f_cards_g
=
None
return
f_cards_g
,
f_cards
,
c_mask
class
GlobalEncoder
(
nn
.
Module
):
class
GlobalEncoder
(
nn
.
Module
):
...
@@ -229,35 +247,26 @@ class Encoder(nn.Module):
...
@@ -229,35 +247,26 @@ class Encoder(nn.Module):
id_embed
=
embed
(
n_embed
,
embed_dim
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
card_encoder
=
CardEncoder
(
card_encoder
=
CardEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
,
oppo_info
=
self
.
oppo_info
)
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
action_encoder
=
ActionEncoderCls
(
action_encoder
=
ActionEncoderCls
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards_g
=
x
[
'g_cards_'
]
if
self
.
oppo_info
else
None
x_cards
=
x
[
'cards_'
]
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
mask
=
x
[
'mask_'
]
batch_size
=
x_global
.
shape
[
0
]
batch_size
=
x_global
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
valid
=
x_global
[:,
-
1
]
==
0
n_cards
=
x_cards
.
shape
[
-
2
]
if
self
.
oppo_info
:
x_cards
=
jnp
.
concatenate
([
x_cards
,
x_cards_g
],
axis
=-
2
)
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
id_embed
(
x_id
)
x_id
=
id_embed
(
x_id
)
if
self
.
freeze_id
:
if
self
.
freeze_id
:
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
f_cards
,
c_mask
=
card_encoder
(
x_id
,
x_cards
[:,
:,
2
:])
f_cards_g
,
f_cards_me
,
c_mask
=
card_encoder
(
x_id
,
x_cards
[:,
:,
2
:],
mask
)
if
self
.
oppo_info
:
f_cards_me
,
f_cards_g
=
jnp
.
split
(
f_cards
,
[
n_cards
],
axis
=-
2
)
else
:
f_cards_me
,
f_cards_g
=
f_cards
,
None
# Cards
# Cards
fs_g_card
=
[]
fs_g_card
=
[]
...
@@ -526,19 +535,18 @@ class GlobalCritic(nn.Module):
...
@@ -526,19 +535,18 @@ class GlobalCritic(nn.Module):
channels
:
Sequence
[
int
]
=
(
128
,
128
)
channels
:
Sequence
[
int
]
=
(
128
,
128
)
dtype
:
Optional
[
jnp
.
dtype
]
=
None
dtype
:
Optional
[
jnp
.
dtype
]
=
None
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
rstate1
,
rstate2
,
g_cards
):
def
__call__
(
self
,
f_state_r1
,
f_state_r2
,
f_state
,
g_cards
):
f_state
=
jnp
.
concatenate
([
rstate1
[
0
],
rstate1
[
1
],
rstate2
[
0
],
rstate2
[
0
]
],
axis
=-
1
)
f_state
=
jnp
.
concatenate
([
f_state_r1
,
f_state_r2
,
f_state
,
g_cards
],
axis
=-
1
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
x
=
mlp
(
self
.
channels
,
last_lin
=
True
)(
f_state
)
x
=
mlp
(
self
.
channels
,
last_lin
=
True
)(
f_state
)
c
=
self
.
channels
[
-
1
]
# c = self.channels[-1]
t
=
nn
.
Dense
(
c
*
2
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
g_cards
)
# t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
s
,
b
=
jnp
.
split
(
t
,
2
,
axis
=-
1
)
# s, b = jnp.split(t, 2, axis=-1)
x
=
x
*
s
+
b
# x = x * s + b
# x = mlp([c], last_lin=False)(x)
x
=
mlp
([
c
],
last_lin
=
False
)(
x
)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
return
x
return
x
...
@@ -720,9 +728,11 @@ class RNNAgent(nn.Module):
...
@@ -720,9 +728,11 @@ class RNNAgent(nn.Module):
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
CriticCls
=
CrossCritic
if
self
.
batch_norm
else
Critic
cs
=
[
self
.
critic_width
]
*
self
.
critic_depth
critic
=
CriticCls
(
channels
=
cs
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
if
self
.
oppo_info
:
if
self
.
oppo_info
:
critic
=
GlobalCritic
(
channels
=
[
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
if
not
multi_step
:
if
not
multi_step
:
if
isinstance
(
rstate
[
0
],
tuple
):
if
isinstance
(
rstate
[
0
],
tuple
):
rstate1_t
,
rstate2_t
=
rstate
rstate1_t
,
rstate2_t
=
rstate
...
@@ -735,12 +745,9 @@ class RNNAgent(nn.Module):
...
@@ -735,12 +745,9 @@ class RNNAgent(nn.Module):
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x1
,
x2
),
rstate1
,
rstate2
)
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x1
,
x2
),
rstate1
,
rstate2
)
rstate2_t
=
jax
.
tree
.
map
(
rstate2_t
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x2
,
x1
),
rstate1
,
rstate2
)
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x2
,
x1
),
rstate1
,
rstate2
)
value
=
critic
(
rstate1_t
,
rstate2_t
,
f_g
)
f_critic
=
jnp
.
concatenate
([
rstate1_t
[
1
],
rstate2_t
[
1
],
f_state
,
f_g
],
axis
=-
1
)
value
=
critic
(
f_critic
,
train
)
else
:
else
:
CriticCls
=
CrossCritic
if
self
.
batch_norm
else
Critic
cs
=
[
self
.
critic_width
]
*
self
.
critic_depth
critic
=
CriticCls
(
channels
=
cs
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
value
=
critic
(
f_state_r
,
train
)
value
=
critic
(
f_state_r
,
train
)
if
self
.
int_head
:
if
self
.
int_head
:
...
...
ygoai/rl/jax/agent2.py
0 → 100644
View file @
72a1fd28
This diff is collapsed.
Click to expand it.
ygoai/rl/jax/utils.py
View file @
72a1fd28
...
@@ -10,8 +10,6 @@ import optax
...
@@ -10,8 +10,6 @@ import optax
import
numpy
as
np
import
numpy
as
np
from
ygoai.rl.env
import
RecordEpisodeStatistics
def
masked_mean
(
x
,
valid
):
def
masked_mean
(
x
,
valid
):
x
=
jnp
.
where
(
valid
,
x
,
jnp
.
zeros_like
(
x
))
x
=
jnp
.
where
(
valid
,
x
,
jnp
.
zeros_like
(
x
))
...
...
ygoai/rl/utils.py
View file @
72a1fd28
import
re
import
re
import
numpy
as
np
import
gymnasium
as
gym
import
pickle
import
optree
import
optree
import
torch
import
torch
from
ygoai.rl.env
import
RecordEpisodeStatistics
from
ygoai.rl.env
import
RecordEpisodeStatistics
,
EnvPreprocess
def
split_param_groups
(
model
,
regex
):
def
split_param_groups
(
model
,
regex
):
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
72a1fd28
...
@@ -1540,7 +1540,7 @@ public:
...
@@ -1540,7 +1540,7 @@ public:
Spec
<
uint8_t
>
({
conf
[
"max_options"
_
],
n_action_feats
})),
Spec
<
uint8_t
>
({
conf
[
"max_options"
_
],
n_action_feats
})),
"obs:h_actions_"
_
.
Bind
(
"obs:h_actions_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"n_history_actions"
_
],
n_action_feats
+
2
})),
Spec
<
uint8_t
>
({
conf
[
"n_history_actions"
_
],
n_action_feats
+
2
})),
"obs:
g_cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
41
})),
"obs:
mask_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
14
})),
"info:num_options"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
conf
[
"max_options"
_
]
-
1
})),
"info:num_options"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
conf
[
"max_options"
_
]
-
1
})),
"info:to_play"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:to_play"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:is_selfplay"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
"info:is_selfplay"
_
.
Bind
(
Spec
<
int
>
({},
{
0
,
1
})),
...
@@ -2337,9 +2337,18 @@ public:
...
@@ -2337,9 +2337,18 @@ public:
return
;
return
;
}
}
auto
[
spec_infos
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
SpecInfos
spec_infos
;
std
::
vector
<
int
>
loc_n_cards
;
if
(
spec_
.
config
[
"oppo_info"
_
])
{
if
(
spec_
.
config
[
"oppo_info"
_
])
{
_set_obs_g_cards
(
state
[
"obs:g_cards_"
_
]);
_set_obs_g_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
auto
[
spec_infos_
,
loc_n_cards_
]
=
_set_obs_mask
(
state
[
"obs:mask_"
_
],
to_play_
);
spec_infos
=
spec_infos_
;
loc_n_cards
=
loc_n_cards_
;
}
else
{
auto
[
spec_infos_
,
loc_n_cards_
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
spec_infos
=
spec_infos_
;
loc_n_cards
=
loc_n_cards_
;
}
}
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
_set_obs_global
(
state
[
"obs:global_"
_
],
to_play_
,
loc_n_cards
);
...
@@ -2448,27 +2457,85 @@ private:
...
@@ -2448,27 +2457,85 @@ private:
return
{
spec_infos
,
loc_n_cards
};
return
{
spec_infos
,
loc_n_cards
};
}
}
void
_set_obs_g_cards
(
TArray
<
uint8_t
>
&
f_cards
)
{
void
_set_obs_g_cards
(
TArray
<
uint8_t
>
&
f_cards
,
PlayerId
to_play
)
{
int
offset
=
0
;
int
offset
=
0
;
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
const
PlayerId
player
=
(
to_play
+
pi
)
%
2
;
std
::
vector
<
uint8_t
>
configs
=
{
std
::
vector
<
uint8_t
>
configs
=
{
LOCATION_DECK
,
LOCATION_HAND
,
LOCATION_MZONE
,
LOCATION_DECK
,
LOCATION_HAND
,
LOCATION_MZONE
,
LOCATION_SZONE
,
LOCATION_GRAVE
,
LOCATION_REMOVED
,
LOCATION_SZONE
,
LOCATION_GRAVE
,
LOCATION_REMOVED
,
LOCATION_EXTRA
,
LOCATION_EXTRA
,
};
};
for
(
auto
location
:
configs
)
{
for
(
auto
location
:
configs
)
{
std
::
vector
<
Card
>
cards
=
get_cards_in_location
(
p
i
,
location
);
std
::
vector
<
Card
>
cards
=
get_cards_in_location
(
p
layer
,
location
);
int
n_cards
=
cards
.
size
();
int
n_cards
=
cards
.
size
();
for
(
int
i
=
0
;
i
<
n_cards
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n_cards
;
++
i
)
{
const
auto
&
c
=
cards
[
i
];
const
auto
&
c
=
cards
[
i
];
CardId
card_id
=
c_get_card_id
(
c
.
code_
);
CardId
card_id
=
c_get_card_id
(
c
.
code_
);
_set_obs_card_
(
f_cards
,
offset
,
c
,
false
,
card_id
,
false
);
_set_obs_card_
(
f_cards
,
offset
,
c
,
false
,
card_id
,
false
);
offset
++
;
offset
++
;
if
(
offset
==
(
spec_
.
config
[
"max_cards"
_
]
*
2
-
1
))
{
return
;
}
}
}
}
}
}
}
}
}
std
::
tuple
<
SpecInfos
,
std
::
vector
<
int
>>
_set_obs_mask
(
TArray
<
uint8_t
>
&
mask
,
PlayerId
to_play
)
{
SpecInfos
spec_infos
;
std
::
vector
<
int
>
loc_n_cards
;
int
offset
=
0
;
for
(
auto
pi
=
0
;
pi
<
2
;
pi
++
)
{
const
PlayerId
player
=
(
to_play
+
pi
)
%
2
;
const
bool
opponent
=
pi
==
1
;
std
::
vector
<
std
::
pair
<
uint8_t
,
bool
>>
configs
=
{
{
LOCATION_DECK
,
true
},
{
LOCATION_HAND
,
true
},
{
LOCATION_MZONE
,
false
},
{
LOCATION_SZONE
,
false
},
{
LOCATION_GRAVE
,
false
},
{
LOCATION_REMOVED
,
false
},
{
LOCATION_EXTRA
,
true
},
};
for
(
auto
&
[
location
,
hidden_for_opponent
]
:
configs
)
{
// check this
if
(
opponent
&&
(
revealed_
.
size
()
!=
0
))
{
hidden_for_opponent
=
false
;
}
if
(
opponent
&&
hidden_for_opponent
)
{
auto
n_cards
=
YGO_QueryFieldCount
(
pduel_
,
player
,
location
);
loc_n_cards
.
push_back
(
n_cards
);
for
(
auto
i
=
0
;
i
<
n_cards
;
i
++
)
{
mask
(
offset
,
1
)
=
1
;
mask
(
offset
,
3
)
=
1
;
offset
++
;
}
}
else
{
std
::
vector
<
Card
>
cards
=
get_cards_in_location
(
player
,
location
);
int
n_cards
=
cards
.
size
();
loc_n_cards
.
push_back
(
n_cards
);
for
(
int
i
=
0
;
i
<
n_cards
;
++
i
)
{
const
auto
&
c
=
cards
[
i
];
auto
spec
=
c
.
get_spec
(
opponent
);
bool
hide
=
false
;
if
(
opponent
)
{
hide
=
c
.
position_
&
POS_FACEDOWN
;
if
(
revealed_
.
find
(
spec
)
!=
revealed_
.
end
())
{
hide
=
false
;
}
}
CardId
card_id
=
0
;
if
(
!
hide
)
{
card_id
=
c_get_card_id
(
c
.
code_
);
}
_set_obs_mask_
(
mask
,
offset
,
c
,
hide
);
offset
++
;
spec_infos
[
spec
]
=
{
static_cast
<
uint16_t
>
(
offset
),
card_id
};
}
}
}
}
return
{
spec_infos
,
loc_n_cards
};
}
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
void
_set_obs_card_
(
TArray
<
uint8_t
>
&
f_cards
,
int
offset
,
const
Card
&
c
,
bool
hide
,
CardId
card_id
=
0
,
bool
global
=
false
)
{
bool
hide
,
CardId
card_id
=
0
,
bool
global
=
false
)
{
...
@@ -2531,6 +2598,54 @@ private:
...
@@ -2531,6 +2598,54 @@ private:
}
}
}
}
void
_set_obs_mask_
(
TArray
<
uint8_t
>
&
mask
,
int
offset
,
const
Card
&
c
,
bool
hide
,
CardId
card_id
=
0
,
bool
global
=
false
)
{
// check offset exceeds max_cards
uint8_t
location
=
c
.
location_
;
bool
overlay
=
location
&
LOCATION_OVERLAY
;
if
(
overlay
)
{
location
=
location
&
0x7f
;
}
if
(
overlay
)
{
hide
=
false
;
}
if
(
!
hide
)
{
if
(
card_id
!=
0
)
{
mask
(
offset
,
0
)
=
1
;
}
}
mask
(
offset
,
1
)
=
1
;
if
(
location
==
LOCATION_MZONE
||
location
==
LOCATION_SZONE
||
location
==
LOCATION_GRAVE
)
{
mask
(
offset
,
2
)
=
1
;
}
mask
(
offset
,
3
)
=
1
;
if
(
overlay
)
{
mask
(
offset
,
4
)
=
1
;
mask
(
offset
,
5
)
=
1
;
}
else
{
if
(
location
==
LOCATION_DECK
||
location
==
LOCATION_HAND
||
location
==
LOCATION_EXTRA
)
{
if
(
hide
||
(
c
.
position_
&
POS_FACEDOWN
))
{
mask
(
offset
,
4
)
=
1
;
}
}
else
{
mask
(
offset
,
4
)
=
1
;
}
}
if
(
!
hide
)
{
mask
(
offset
,
6
)
=
1
;
mask
(
offset
,
7
)
=
1
;
mask
(
offset
,
8
)
=
1
;
mask
(
offset
,
9
)
=
1
;
mask
(
offset
,
10
)
=
1
;
mask
(
offset
,
11
)
=
1
;
mask
(
offset
,
12
)
=
1
;
mask
(
offset
,
13
)
=
1
;
}
}
void
_set_obs_global
(
TArray
<
uint8_t
>
&
feat
,
PlayerId
player
,
const
std
::
vector
<
int
>
&
loc_n_cards
)
{
void
_set_obs_global
(
TArray
<
uint8_t
>
&
feat
,
PlayerId
player
,
const
std
::
vector
<
int
>
&
loc_n_cards
)
{
uint8_t
me
=
player
;
uint8_t
me
=
player
;
uint8_t
op
=
1
-
player
;
uint8_t
op
=
1
-
player
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment