Biluo Shen / ygo-agent / Commits

Commit 72a1fd28, authored Jun 19, 2024 by sbl1996@126.com

Add oppo_info

parent 5cd9807d

Showing 8 changed files with 2068 additions and 53 deletions (+2068, -53)
scripts/cleanba.py              +5     -5
scripts/cleanba_g.py            +1213  -0
ygoai/rl/env.py                 +27    -2
ygoai/rl/jax/agent.py           +42    -35
ygoai/rl/jax/agent2.py          +660   -0
ygoai/rl/jax/utils.py           +0     -2
ygoai/rl/utils.py               +1     -4
ygoenv/ygoenv/ygopro/ygopro.h   +120   -5
scripts/cleanba.py (view file @ 72a1fd28)
@@ -23,9 +23,10 @@ from rich.pretty import pprint
 from tensorboardX import SummaryWriter

 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent import RNNAgent, ModelArgs
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample, TrainState
+from ygoai.rl.jax.utils import masked_normalize, categorical_sample, TrainState
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
 from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
@@ -356,6 +357,7 @@ def rollout(
         args.local_env_threads,
         thread_affinity_offset=device_thread_id * args.local_env_threads,
     )
+    envs = EnvPreprocess(envs, skip_mask=not args.m1.oppo_info)
     envs = RecordEpisodeStatistics(envs)

     eval_envs = make_env(
@@ -363,6 +365,7 @@ def rollout(
         local_seed + 100000, args.local_eval_episodes,
         args.local_eval_episodes // 4, mode=eval_mode, eval=True)
+    eval_envs = EnvPreprocess(eval_envs, skip_mask=True)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)
@@ -440,9 +443,6 @@ def rollout(
     init_rstates = []

-    # @jax.jit
-    # def prepare_data(storage: List[Transition]) -> Transition:
-    #     return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:
         return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
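The rewritten prepare_data only stacks the per-step transitions; the split across learner devices now happens later in the rollout loop. A minimal sketch of the stacking semantics (the two-field tuple below is a hypothetical stand-in for the full Transition):

import jax
import jax.numpy as jnp
from typing import NamedTuple

class T(NamedTuple):  # stand-in for the full Transition
    rewards: jnp.ndarray
    dones: jnp.ndarray

# two timesteps of data from 4 envs
storage = [T(jnp.zeros(4), jnp.zeros(4)), T(jnp.ones(4), jnp.ones(4))]

# jax.tree.map(f, *storage) calls f with the same field from every step,
# so each field gains a leading time axis
stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
print(stacked.rewards.shape)  # (2, 4): (num_steps, num_envs)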
@@ -566,7 +566,7 @@ def rollout(
         for x in partitioned_storage:
             if isinstance(x, dict):
                 x = {
-                    k: jax.device_put_sharded(v, devices=learner_devices)
+                    k: jax.device_put_sharded(v, devices=learner_devices) if v is not None else None
                     for k, v in x.items()
                 }
             elif x is not None:
scripts/cleanba_g.py (new file, 0 → 100644, view file @ 72a1fd28)
import os
import shutil
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field, asdict
from types import SimpleNamespace
from typing import List, NamedTuple, Optional, Literal
from functools import partial

import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import distrax
import tyro
from rich.pretty import pprint
from tensorboardX import SummaryWriter

from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent as RNNAgentE, ModelArgs as ModelArgsE
from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import masked_normalize, categorical_sample, TrainState
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
    ach_loss, policy_gradient_loss, vtrace, vtrace_sep, truncated_gae, truncated_gae_sep


os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
    exp_name: str = os.path.basename(__file__).rstrip(".py")
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    log_frequency: int = 10
    """the logging frequency of the model performance (in terms of `updates`)"""
    time_log_freq: int = 0
    """the logging frequency of the deck time statistics, 0 to disable"""
    save_interval: int = 400
    """the frequency of saving the model (in terms of `updates`)"""
    checkpoint: Optional[str] = None
    """the path to the model checkpoint to load"""
    timeout: int = 600
    """the timeout of the environment step"""
    debug: bool = False
    """whether to run the script in debug mode"""
    tb_dir: str = "runs"
    """the directory to save the tensorboard logs"""
    tb_offset: int = 0
    """the step offset of the tensorboard logs"""
    run_name: Optional[str] = None
    """the name of the tensorboard run"""
    ckpt_dir: str = "checkpoints"
    """the directory to save the model checkpoints"""
    gcs_bucket: Optional[str] = None
    """the GCS bucket to save the model checkpoints"""

    # Algorithm specific arguments
    env_id: str = "YGOPro-v1"
    """the id of the environment"""
    deck: str = "../assets/deck"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    embedding_file: Optional[str] = None
    """the embedding file for card embeddings"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 32
    """the number of history actions to use"""
    greedy_reward: bool = False
    """whether to use greedy reward (faster kill higher reward)"""

    total_timesteps: int = 50000000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    local_num_envs: int = 128
    """the number of parallel game environments"""
    local_env_threads: Optional[int] = None
    """the number of threads to use for environment"""
    num_actor_threads: int = 2
    """the number of actor threads to use"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    collect_steps: Optional[int] = None
    """the number of steps to compute the advantages"""
    segment_length: Optional[int] = None
    """the length of the segment for training"""
    anneal_lr: bool = False
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 1.0
    """the discount factor gamma"""
    num_minibatches: int = 64
    """the number of mini-batches"""
    update_epochs: int = 2
    """the K epochs to update the policy"""
    norm_adv: bool = False
    """Toggles advantages normalization"""
    burn_in_steps: Optional[int] = None
    """the number of burn-in steps for training (for R2D2)"""
    upgo: bool = True
    """Toggle the use of UPGO for advantages"""
    sep_value: bool = True
    """Whether separate value function computation for each player"""
    value: Literal["vtrace", "gae"] = "vtrace"
    """the method to learn the value function"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    c_clip_min: float = 0.001
    """the minimum value of the importance sampling clipping"""
    c_clip_max: float = 1.007
    """the maximum value of the importance sampling clipping"""
    rho_clip_min: float = 0.001
    """the minimum value of the importance sampling clipping"""
    rho_clip_max: float = 1.007
    """the maximum value of the importance sampling clipping"""
    ppo_clip: bool = True
    """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
    clip_coef: float = 0.25
    """the PPO surrogate clipping coefficient"""
    dual_clip_coef: Optional[float] = 3.0
    """the dual surrogate clipping coefficient, typically 3.0"""
    spo_kld_max: Optional[float] = None
    """the maximum KLD for the SPO policy, typically 0.02"""
    logits_threshold: Optional[float] = None
    """the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
    vloss_clip: Optional[float] = None
    """the value loss clipping coefficient"""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 1.0
    """coefficient of the value function"""
    max_grad_norm: float = 1.0
    """the maximum norm for the gradient clipping"""

    m1: ModelArgs = field(default_factory=lambda: ModelArgs())
    """the model arguments for the agent"""
    m2: ModelArgsE = field(default_factory=lambda: ModelArgsE())
    """the model arguments for the eval agent"""

    actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
    """the device ids that actor workers will use"""
    learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
    """the device ids that learner workers will use"""
    distributed: bool = False
    """whether to use `jax.distributed`"""
    concurrency: bool = True
    """whether to run the actor and learner concurrently"""
    bfloat16: bool = False
    """whether to use bfloat16 for the agent"""
    thread_affinity: bool = False
    """whether to use thread affinity for the environment"""

    eval_checkpoint: Optional[str] = None
    """the path to the model checkpoint to evaluate"""
    local_eval_episodes: int = 128
    """the number of episodes to evaluate the model"""
    eval_interval: int = 100
    """the number of iterations to evaluate the model"""

    # runtime arguments to be filled in
    local_batch_size: int = 0
    local_minibatch_size: int = 0
    world_size: int = 0
    local_rank: int = 0
    num_envs: int = 0
    batch_size: int = 0
    minibatch_size: int = 0
    num_updates: int = 0
    global_learner_decices: Optional[List[str]] = None
    actor_devices: Optional[List[str]] = None
    learner_devices: Optional[List[str]] = None
    num_embeddings: Optional[int] = None
    freeze_id: Optional[bool] = None
    deck_names: Optional[List[str]] = None
    real_seed: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self',
             thread_affinity_offset=-1, eval=False):
    if not args.thread_affinity:
        thread_affinity_offset = -1
    if thread_affinity_offset >= 0:
        print("Binding to thread offset", thread_affinity_offset)
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=num_threads,
        thread_affinity_offset=thread_affinity_offset,
        seed=seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        async_reset=False,
        greedy_reward=args.greedy_reward if not eval else True,
        play_mode=mode,
        timeout=args.timeout,
        oppo_info=False if eval else True,
    )
    envs.num_envs = num_envs
    return envs
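For reference, a small sketch of how the eval flag above flips the two mode-dependent arguments (plain Python; the helper name and values are illustrative only):

def mode_kwargs(eval, greedy_reward):
    # eval envs always use greedy reward and never request opponent info;
    # training envs follow the config and expose oppo_info for the critic
    return dict(
        greedy_reward=greedy_reward if not eval else True,
        oppo_info=False if eval else True,
    )

print(mode_kwargs(eval=False, greedy_reward=False))  # {'greedy_reward': False, 'oppo_info': True}
print(mode_kwargs(eval=True, greedy_reward=False))   # {'greedy_reward': True, 'oppo_info': False}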
class Transition(NamedTuple):
    obs: list
    dones: list
    actions: list
    logits: list
    rewards: list
    mains: list
    next_dones: list


def create_agent(args, actor=False, eval=False):
    if eval:
        return RNNAgentE(
            embedding_shape=args.num_embeddings,
            dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
            param_dtype=jnp.float32,
            **asdict(args.m2),
        )
    else:
        if actor == 'eval':
            actor = True
            eval = True
        else:
            eval = False
        return RNNAgent(
            embedding_shape=args.num_embeddings,
            dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
            param_dtype=jnp.float32,
            freeze_id=args.freeze_id,
            actor=actor,
            eval=eval,
            **asdict(args.m1),
        )


def get_variables(agent_state):
    batch_stats = getattr(agent_state, "batch_stats", None)
    variables = {'params': agent_state.params}
    if batch_stats is not None:
        variables['batch_stats'] = batch_stats
    return variables


def get_actor_variables(agent_state):
    batch_stats = getattr(agent_state, "batch_stats", None)
    variables = {'params': agent_state.params["actor"]}
    if batch_stats is not None:
        variables['batch_stats'] = batch_stats["actor"]
    return variables


def init_rnn_state(num_envs, rnn_channels):
    return (
        np.zeros((num_envs, rnn_channels)),
        np.zeros((num_envs, rnn_channels)),
    )
def reshape_minibatch(x, multi_step, num_minibatches, num_steps, segment_length=None, key=None):
    # if segment_length is None,
    #   n_mb = num_minibatches
    #   if multi_step, from (num_steps, num_envs, ...) to
    #     (n_mb, num_steps * (num_envs // n_mb), ...)
    #   else, from (num_envs, ...) to
    #     (n_mb, num_envs // n_mb, ...)
    # else,
    #   n_mb_t = num_steps // segment_length
    #   n_mb_e = num_minibatches // n_mb_t
    #   if multi_step, from (num_steps, num_envs, ...) to
    #     (n_mb_e, n_mb_t, segment_length * (num_envs // n_mb_e), ...)
    #   else, from (num_envs, ...) to
    #     (n_mb_e, num_envs // n_mb_e, ...)
    if key is not None:
        x = jax.random.permutation(key, x, axis=1 if multi_step else 0)
    N = num_minibatches
    if segment_length is None:
        if multi_step:
            x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
            x = x.transpose(1, 0, *range(2, x.ndim))
            x = x.reshape(N, -1, *x.shape[3:])
        else:
            x = jnp.reshape(x, (N, -1) + x.shape[1:])
    else:
        M = segment_length
        Nt = num_steps // M
        Ne = N // Nt
        if multi_step:
            x = jnp.reshape(x, (Nt, M, Ne, -1) + x.shape[2:])
            x = x.transpose(2, 0, 1, *range(3, x.ndim))
            x = jnp.reshape(x, (Ne, Nt, -1) + x.shape[4:])
        else:
            x = jnp.reshape(x, (Ne, -1) + x.shape[1:])
    return x
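A quick shape check of the two documented paths, assuming the function as defined above (sizes are arbitrary):

import jax.numpy as jnp

num_steps, num_envs, num_minibatches, segment_length = 8, 16, 4, 2
x = jnp.zeros((num_steps, num_envs, 5))

# no segmenting: (num_steps, num_envs, ...) -> (n_mb, num_steps * (num_envs // n_mb), ...)
out = reshape_minibatch(x, True, num_minibatches, num_steps)
assert out.shape == (4, 8 * 4, 5)

# segmenting: n_mb_t = 8 // 2 = 4 time minibatches, n_mb_e = 4 // 4 = 1 env minibatch
out = reshape_minibatch(x, True, num_minibatches, num_steps, segment_length)
assert out.shape == (1, 4, 2 * 16, 5)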
def advantage_fn(args, next_v, values, rewards, next_dones, mains, ratios=None, return_carry=False):
    # TODO: TD(lambda) for multi-step
    if args.value == "gae":
        adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
        return adv_fn(
            next_v, values, rewards, next_dones, mains,
            args.gamma, args.gae_lambda, args.upgo, return_carry=return_carry)
    else:
        adv_fn = vtrace_sep if args.sep_value else vtrace
        if ratios is None:
            ratios = jnp.ones_like(values)
        return adv_fn(
            next_v, ratios, values, rewards, next_dones, mains,
            args.gamma, args.rho_clip_min, args.rho_clip_max,
            args.c_clip_min, args.c_clip_max, args.upgo, return_carry=return_carry)
def rollout(
    key: jax.random.PRNGKey,
    args: Args,
    rollout_queue,
    params_queue,
    writer,
    actor_device,
    learner_devices,
    device_thread_id,
):
    eval_mode = 'self' if args.eval_checkpoint else 'bot'
    if eval_mode != 'bot':
        eval_params = params_queue.get()

    local_seed = args.real_seed + device_thread_id * args.local_num_envs
    np.random.seed(local_seed)

    envs = make_env(
        args,
        local_seed,
        args.local_num_envs,
        args.local_env_threads,
        thread_affinity_offset=device_thread_id * args.local_env_threads,
    )
    envs = EnvPreprocess(envs, skip_mask=False)
    envs = RecordEpisodeStatistics(envs)

    eval_envs = make_env(
        args, local_seed + 100000, args.local_eval_episodes,
        args.local_eval_episodes // 4, mode=eval_mode, eval=True)
    eval_envs = EnvPreprocess(eval_envs, skip_mask=True)
    eval_envs = RecordEpisodeStatistics(eval_envs)

    len_actor_device_ids = len(args.actor_device_ids)
    n_actors = args.num_actor_threads * len_actor_device_ids
    global_step = 0
    start_time = time.time()
    warmup_step = 0
    other_time = 0
    avg_ep_returns = deque(maxlen=1000)
    avg_win_rates = deque(maxlen=1000)

    actor = create_agent(args, actor=True)
    eval_agent1 = create_agent(args, actor='eval')
    eval_agent2 = create_agent(args, eval=eval_mode != 'bot')

    @jax.jit
    def get_action(params, obs, rstate):
        rstate, logits = eval_agent1.apply(params, obs, rstate)[:2]
        return rstate, logits.argmax(axis=1)

    @jax.jit
    def get_action_battle(params1, params2, obs, rstate1, rstate2, main, done):
        next_rstate1, logits1 = eval_agent1.apply(params1, obs, rstate1)[:2]
        next_rstate2, logits2 = eval_agent2.apply(params2, obs, rstate2)[:2]
        logits = jnp.where(main[:, None], logits1, logits2)
        rstate1 = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
        rstate2 = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
        rstate1, rstate2 = jax.tree.map(
            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
        return rstate1, rstate2, logits.argmax(axis=1)
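The repeated jnp.where(main[:, None], ...) above advances only the acting player's recurrent state in each environment and leaves the other player's state untouched. A self-contained sketch of that selection:

import jax.numpy as jnp

main = jnp.array([True, False])  # env 0: player 1 acts; env 1: player 2 acts
old = jnp.zeros((2, 3))          # player 1's previous rstate
new = jnp.ones((2, 3))           # player 1's freshly computed rstate

# keep the new state only in envs where player 1 actually acted
rstate1 = jnp.where(main[:, None], new, old)
print(rstate1)  # row 0 advanced to ones, row 1 held at zeros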
    @jax.jit
    def sample_action(params, next_obs, rstate1, rstate2, main, done, key):
        (rstate1, rstate2), logits = actor.apply(
            params, next_obs, (rstate1, rstate2), done, main)
        action, key = categorical_sample(logits, key)
        return next_obs, done, main, rstate1, rstate2, action, logits, key

    @jax.jit
    def compute_advantage_carry(next_value, values, rewards, next_dones, mains):
        return advantage_fn(
            args, next_value, values, rewards, next_dones, mains, return_carry=True)

    deck_names = args.deck_names
    deck_avg_times = {name: 0 for name in deck_names}
    deck_max_times = {name: 0 for name in deck_names}
    deck_time_count = {name: 0 for name in deck_names}

    # put data in the last index
    params_queue_get_time = deque(maxlen=10)
    rollout_time = deque(maxlen=10)
    actor_policy_version = 0
    next_obs, info = envs.reset()
    next_to_play = info["to_play"]
    next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
    next_rstate1 = next_rstate2 = actor.init_rnn_state(args.local_num_envs)
    eval_rstate1 = eval_agent1.init_rnn_state(args.local_eval_episodes)
    eval_rstate2 = eval_agent2.init_rnn_state(args.local_eval_episodes)
    next_rstate1, next_rstate2, eval_rstate1, eval_rstate2 = \
        jax.device_put([next_rstate1, next_rstate2, eval_rstate1, eval_rstate2], actor_device)

    main_player = np.concatenate([
        np.zeros(args.local_num_envs // 2, dtype=np.int64),
        np.ones(args.local_num_envs // 2, dtype=np.int64)
    ])
    np.random.shuffle(main_player)

    start_step = 0
    storage = []
    init_rstates = []

    @jax.jit
    def prepare_data(storage: List[Transition]) -> Transition:
        return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
    for update in range(1, args.num_updates + 2):
        if update == 10:
            start_time = time.time()
            warmup_step = global_step

        update_time_start = time.time()
        inference_time = 0
        env_time = 0
        params_queue_get_time_start = time.time()
        if args.concurrency:
            if update != 2:
                params = params_queue.get()
                # params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
                actor_policy_version += 1
        else:
            params = params_queue.get()
            actor_policy_version += 1
        params_queue_get_time.append(time.time() - params_queue_get_time_start)

        rollout_time_start = time.time()
        for k in range(start_step, args.collect_steps):
            if k % args.num_steps == 0:
                init_rstate1, init_rstate2 = jax.tree.map(
                    lambda x: x.copy(), (next_rstate1, next_rstate2))
                init_rstates.append((init_rstate1, init_rstate2))

            global_step += args.local_num_envs * n_actors * args.world_size
            main = next_to_play == main_player

            inference_time_start = time.time()
            cached_next_obs, cached_next_done, cached_main, \
                next_rstate1, next_rstate2, action, logits, key = sample_action(
                    params, next_obs, next_rstate1, next_rstate2, main, next_done, key)

            cpu_action = np.array(action)
            inference_time += time.time() - inference_time_start

            _start = time.time()
            next_obs, next_reward, next_done, info = envs.step(cpu_action)
            next_to_play = info["to_play"]
            env_time += time.time() - _start

            storage.append(
                Transition(
                    obs=cached_next_obs,
                    dones=cached_next_done,
                    mains=cached_main,
                    actions=action,
                    logits=logits,
                    rewards=next_reward,
                    next_dones=next_done,
                )
            )

            for idx, d in enumerate(next_done):
                if not d:
                    continue
                cur_main = main[idx]
                if args.time_log_freq:
                    for i in range(2):
                        deck_time = info['step_time'][idx][i]
                        deck_name = deck_names[info['deck'][idx][i]]
                        time_count = deck_time_count[deck_name]
                        avg_time = deck_avg_times[deck_name]
                        avg_time = avg_time * (time_count / (time_count + 1)) + \
                            deck_time / (time_count + 1)
                        max_time = max(deck_time, deck_max_times[deck_name])
                        deck_avg_times[deck_name] = avg_time
                        deck_max_times[deck_name] = max_time
                        deck_time_count[deck_name] += 1
                        if deck_time_count[deck_name] % args.time_log_freq == 0:
                            print(f"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}")
                episode_reward = info['r'][idx] * (1 if cur_main else -1)
                win = 1 if episode_reward > 0 else 0
                avg_ep_returns.append(episode_reward)
                avg_win_rates.append(win)

        rollout_time.append(time.time() - rollout_time_start)
        start_step = args.collect_steps - args.num_steps

        next_main = main_player == next_to_play
        if args.collect_steps == args.num_steps:
            storage_t = storage
            storage = []
            next_data = (next_obs, next_main)
        else:
            storage_t = storage[:args.num_steps]
            storage = storage[args.num_steps:]
            values, rewards, next_dones, mains = prepare_data(
                [(t.values, t.rewards, t.next_dones, t.mains) for t in storage])
            next_value = sample_action(
                params, next_obs, next_rstate1, next_rstate2, next_main, next_done, key)[-2]
            next_value = jnp.where(next_main, next_value, -next_value)
            adv_carry = compute_advantage_carry(next_value, values, rewards, next_dones, mains)
            next_data = adv_carry

        partitioned_storage = jax.tree.map(
            lambda x: jnp.split(x, len(learner_devices), axis=1), prepare_data(storage_t))
        sharded_storage = []
        for x in partitioned_storage:
            if isinstance(x, dict):
                x = {
                    k: jax.device_put_sharded(v, devices=learner_devices)
                    if v is not None else None
                    for k, v in x.items()
                }
            elif x is not None:
                x = jax.device_put_sharded(x, devices=learner_devices)
            sharded_storage.append(x)
        sharded_storage = Transition(*sharded_storage)
        init_rstate = init_rstates.pop(0)
        sharded_data = jax.tree.map(
            lambda x: jax.device_put_sharded(
                np.split(x, len(learner_devices)), devices=learner_devices),
            (init_rstate, next_data))
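The sharding above splits each batch along the env axis and places one shard per learner device. A minimal sketch with whatever devices are locally available:

import jax
import numpy as np

devices = jax.local_devices()
x = np.arange(8 * 4).reshape(8, 4)  # e.g. (num_envs, feature)

# one shard per device, exposed as a single device-sharded array;
# assumes num_envs divides evenly across devices, as in the script
shards = np.split(x, len(devices))
y = jax.device_put_sharded(shards, devices)
print(y.shape)  # (8, 4), laid out across the devices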
        if args.eval_interval and update % args.eval_interval == 0:
            _start = time.time()
            if eval_mode == 'bot':
                predict_fn = lambda *x: get_action(params, *x)
                eval_return, eval_ep_len, eval_win_rate = evaluate(
                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
            else:
                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
                eval_return, eval_ep_len, eval_win_rate = battle(
                    eval_envs, args.local_eval_episodes, predict_fn,
                    eval_rstate1, eval_rstate2)
            eval_time = time.time() - _start
            other_time += eval_time
            eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
        else:
            eval_stats = None

        payload = (
            global_step,
            update,
            sharded_storage,
            *sharded_data,
            np.mean(params_queue_get_time),
            eval_stats,
        )
        rollout_queue.put(payload)

        if update % args.log_frequency == 0:
            avg_episodic_return = np.mean(avg_ep_returns)
            avg_episodic_length = np.mean(envs.returned_episode_lengths)
            SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
            SPS_update = int(args.batch_size / (time.time() - update_time_start))
            tb_global_step = args.tb_offset + global_step
            if device_thread_id == 0:
                print(
                    f"global_step={tb_global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
                )
                time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
                print(
                    f"{time_now} SPS: {SPS}, update: {SPS_update}, "
                    f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
                )
            writer.add_scalar("stats/rollout_time", np.mean(rollout_time), tb_global_step)
            writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, tb_global_step)
            writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, tb_global_step)
            writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), tb_global_step)
            writer.add_scalar("stats/inference_time", inference_time, tb_global_step)
            writer.add_scalar("stats/env_time", env_time, tb_global_step)
            writer.add_scalar("charts/SPS", SPS, tb_global_step)
            writer.add_scalar("charts/SPS_update", SPS_update, tb_global_step)
def main():
    args = tyro.cli(Args)
    args.local_batch_size = int(
        args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
    args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
    assert (
        args.local_num_envs % len(args.learner_device_ids) == 0
    ), "local_num_envs must be divisible by len(learner_device_ids)"
    assert (
        int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
    ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"

    if args.distributed:
        jax.distributed.initialize(
            local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
        )
        print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))

    from jax.experimental.compilation_cache import compilation_cache as cc
    cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))

    args.world_size = jax.process_count()
    args.local_rank = jax.process_index()
    args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
    args.batch_size = args.local_batch_size * args.world_size
    args.minibatch_size = args.local_minibatch_size * args.world_size
    args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
    args.local_env_threads = args.local_env_threads or args.local_num_envs

    if args.segment_length is not None:
        assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"
    args.collect_steps = args.collect_steps or args.num_steps
    assert args.collect_steps >= args.num_steps, "collect_steps must be greater than or equal to num_steps"

    if args.embedding_file:
        embeddings = load_embeddings(args.embedding_file, args.code_list_file)
        embedding_shape = embeddings.shape
        args.num_embeddings = embedding_shape
        args.freeze_id = True if args.freeze_id is None else args.freeze_id
    else:
        embeddings = None
        embedding_shape = None

    local_devices = jax.local_devices()
    global_devices = jax.devices()
    learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
    actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
    global_learner_decices = [
        global_devices[d_id + process_index * len(local_devices)]
        for process_index in range(args.world_size)
        for d_id in args.learner_device_ids
    ]
    global_main_devices = [
        global_devices[process_index * len(local_devices)]
        for process_index in range(args.world_size)
    ]
    print("global_learner_decices", global_learner_decices)
    args.global_learner_decices = [str(item) for item in global_learner_decices]
    args.actor_devices = [str(item) for item in actor_devices]
    args.learner_devices = [str(item) for item in learner_devices]
    pprint(args)

    if args.run_name is None:
        timestamp = int(time.time())
        run_name = f"{args.exp_name}__{args.seed}__{timestamp}"
    else:
        run_name = args.run_name
        timestamp = int(run_name.split("__")[-1])

    dummy_writer = SimpleNamespace()
    dummy_writer.add_scalar = lambda x, y, z: None

    tb_log_dir = f"{args.tb_dir}/{run_name}"
    if args.local_rank == 0 and not args.debug:
        writer = SummaryWriter(tb_log_dir)
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
    else:
        writer = dummy_writer

    def save_fn(obj, path):
        with open(path, "wb") as f:
            f.write(flax.serialization.to_bytes(obj))

    ckpt_maneger = ModelCheckpoint(args.ckpt_dir, save_fn, n_saved=2)

    # seeding
    random.seed(args.seed)
    seed = random.randint(0, 1e8)
    seed_offset = args.local_rank
    seed += seed_offset
    init_key = jax.random.PRNGKey(seed - seed_offset)
    random.seed(seed)
    args.real_seed = random.randint(0, 1e8)
    key = jax.random.PRNGKey(args.real_seed)
    key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
    learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
    actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)

    deck, deck_names = init_ygopro(
        args.env_id, "english", args.deck, args.code_list_file, return_deck_names=True)
    args.deck_names = sorted(deck_names)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    # env setup
    envs = make_env(args, 0, 2, 1)
    obs_space = envs.observation_space
    action_shape = envs.action_space.shape
    print(f"obs_space={obs_space}, action_shape={action_shape}")
    sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
    envs.close()
    del envs

    def linear_schedule(count):
        # anneal learning rate linearly after one training iteration which contains
        # (args.num_minibatches) gradient updates
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac
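With the defaults above (num_minibatches=64, update_epochs=2), one iteration performs 128 gradient steps, so the rate decays by learning_rate / num_updates per iteration. A standalone worked example (num_updates=1000 is an arbitrary illustration, not a script default):

num_minibatches, update_epochs, num_updates, lr = 64, 2, 1000, 3e-4

def linear_schedule(count):
    frac = 1.0 - (count // (num_minibatches * update_epochs)) / num_updates
    return lr * frac

print(linear_schedule(0))    # 3e-04 at the start of training
print(linear_schedule(128))  # 2.997e-04 after one full iteration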
    actor = create_agent(args, actor=True)
    actor_variables = actor.init(init_key, sample_obs, actor.init_rnn_state(1))
    actor_variables = flax.core.unfreeze(actor_variables)
    critic = create_agent(args, actor=False)
    critic_variables = critic.init(init_key, sample_obs, critic.init_rnn_state(1))
    critic_variables = flax.core.unfreeze(critic_variables)

    if embeddings is not None:
        unknown_embed = embeddings.mean(axis=0)
        embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
        for v in [actor_variables, critic_variables]:
            v['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)

    if args.checkpoint:
        with open(args.checkpoint, "rb") as f:
            variables = flax.serialization.from_bytes(variables, f.read())
        actor_variables, critic_variables = variables
        print(f"loaded checkpoint from {args.checkpoint}")

    tx = optax.MultiSteps(
        optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate,
                eps=1e-5),
        ),
        every_k_schedule=1,
    )
    tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
    for v in [actor_variables, critic_variables]:
        if 'batch_stats' not in v:
            v['batch_stats'] = {}
    agent_state = TrainState.create(
        apply_fn=None,
        params={
            "actor": actor_variables['params'],
            "critic": critic_variables['params']},
        tx=tx,
        batch_stats={
            "actor": actor_variables['batch_stats'],
            "critic": critic_variables['batch_stats']},
    )
    agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
    # print(agent.tabulate(agent_key, sample_obs))

    if args.eval_checkpoint:
        eval_agent = create_agent(args, eval=True)
        eval_rstate = eval_agent.init_rnn_state(1)
        eval_variables = eval_agent.init(init_key, sample_obs, eval_rstate)
        with open(args.eval_checkpoint, "rb") as f:
            eval_variables = flax.serialization.from_bytes(eval_variables, f.read())
        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
    else:
        eval_variables = None

    def compute_advantage(new_logits, new_values, next_dones, mains, actions, logits, rewards, next_v):
        num_envs = jax.tree.leaves(next_v)[0].shape[0]
        num_steps = next_dones.shape[0] // num_envs

        def reshape_time_series(x):
            return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])

        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
            new_logits), distrax.Categorical(logits), actions)
        ratios = reshape_time_series(ratios)
        new_values_, rewards, next_dones, mains = jax.tree.map(
            reshape_time_series, (new_values, rewards, next_dones, mains),
        )

        target_values, advantages = advantage_fn(
            args, next_v, new_values_, rewards, next_dones, mains, ratios)
        target_values, advantages = jax.tree.map(
            lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
        return target_values, advantages

    def compute_loss(new_logits, new_values, actions, logits, target_values, advantages, mask, num_steps=None):
        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
            new_logits), distrax.Categorical(logits), actions)
        logratio = jnp.log(ratios)
        approx_kl = (ratios - 1) - logratio

        if args.norm_adv:
            advantages = masked_normalize(advantages, mask, eps=1e-8)

        # Policy loss
        if args.spo_kld_max is not None:
            pg_loss = simple_policy_loss(
                ratios, logits, new_logits, advantages, args.spo_kld_max)
        elif args.logits_threshold is not None:
            pg_loss = ach_loss(
                actions, logits, new_logits, advantages,
                args.logits_threshold, args.clip_coef, args.dual_clip_coef)
        elif args.ppo_clip:
            pg_loss = clipped_surrogate_pg_loss(
                ratios, advantages, args.clip_coef, args.dual_clip_coef)
        else:
            pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
            pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)

        v_loss = mse_loss(new_values, target_values)
        if args.vloss_clip is not None:
            v_loss = jnp.minimum(v_loss, args.vloss_clip)

        ent_loss = entropy_loss(new_logits)

        if args.burn_in_steps:
            mask = jax.tree.map(lambda x: x.reshape(num_steps, -1), mask)
            burn_in_mask = jnp.arange(num_steps) < args.burn_in_steps
            mask = jnp.where(burn_in_mask[:, None], 0.0, mask)
            mask = jnp.reshape(mask, (-1,))
        n_valids = jnp.sum(mask)
        pg_loss, v_loss, ent_loss, approx_kl = jax.tree.map(
            lambda x: jnp.sum(x * mask) / n_valids, (pg_loss, v_loss, ent_loss, approx_kl))

        loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
        return loss, pg_loss, v_loss, ent_loss, approx_kl
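The burn-in branch above implements the R2D2-style warm-up: the first burn_in_steps of every segment are masked out of the loss and only serve to build up RNN state. A toy illustration of the masking:

import jax.numpy as jnp

num_steps, num_envs, burn_in_steps = 4, 3, 2
mask = jnp.ones((num_steps, num_envs))

burn_in_mask = jnp.arange(num_steps) < burn_in_steps
mask = jnp.where(burn_in_mask[:, None], 0.0, mask)
print(mask)  # first two timesteps zeroed, last two kept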
    def apply_fn(actor_variables, critic_variables, obs, init_rstate, dones, mains, train=True):
        mutable = ["batch_stats"] if train else False
        rets_a = actor.apply(
            actor_variables, obs, init_rstate, dones, mains, train=train, mutable=mutable)
        # TODO: improve this
        rstate = critic.init_rnn_state(jax.tree.leaves(init_rstate)[0].shape[0])
        rets_c = critic.apply(
            critic_variables, obs, rstate, dones, train=train, mutable=mutable)
        if train:
            ((rstate1, rstate2), new_logits), state_updates_a = rets_a
            (rstate, new_values), state_updates_c = rets_c
            state_updates = {
                "batch_stats": {
                    "actor": state_updates_a["batch_stats"],
                    "critic": state_updates_c["batch_stats"],
                }}
        else:
            (rstate1, rstate2), new_logits = rets_a
            rstate, new_values = rets_c
            state_updates = {}
        new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
        return ((rstate, rstate1, rstate2), new_logits, new_values), state_updates

    def compute_next_value(variables, next_rstate, next_obs, next_main):
        next_value = critic.apply(variables, next_obs, next_rstate)[1]
        next_value = jax.tree.map(lambda x: x.squeeze(-1), next_value)
        next_value = jax.lax.stop_gradient(next_value)
        next_value = jnp.where(next_main, next_value, -next_value)
        return next_value

    def get_advantage(
        variables, init_rstate, obs, dones, next_dones,
        mains, actions, logits, rewards, next_obs, next_main):
        num_steps = dones.shape[0]
        obs, dones, next_dones, mains, actions, logits, rewards = \
            jax.tree.map(
                lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
                (obs, dones, next_dones, mains, actions, logits, rewards))
        (next_rstate, new_logits, new_values), state_updates = apply_fn(
            variables, obs, init_rstate, dones, next_dones, mains, train=False)
        next_value = compute_next_value(variables, next_rstate[0], next_obs, next_main)
        target_values, advantages = compute_advantage(
            new_logits, new_values, next_dones, mains, actions, logits, rewards, next_value)
        target_values, advantages = jax.tree.map(
            lambda x: jnp.reshape(x, (num_steps, -1) + x.shape[2:]), (target_values, advantages))
        return target_values, advantages

    def get_loss(
        params, batch_stats, init_rstate, obs, dones, next_dones,
        mains, actions, logits, target_values, advantages, mask):
        variables = {'params': params, 'batch_stats': batch_stats}
        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
            variables, obs, init_rstate, dones, next_dones, mains)
        loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
            new_logits, new_values, actions, logits, target_values, advantages, mask, num_steps=None)
        loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
        approx_kl, rstate1, rstate2 = jax.tree.map(
            jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
        return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)

    def get_advantage_loss(
        params, batch_stats, init_rstate, obs, dones, next_dones,
        mains, actions, logits, rewards, mask, next_data):
        num_envs = jax.tree.leaves(next_data)[0].shape[0]
        actor_variables = {'params': params["actor"], 'batch_stats': batch_stats["actor"]}
        critic_variables = {'params': params["critic"], 'batch_stats': batch_stats["critic"]}
        (next_rstate, new_logits, new_values), state_updates = apply_fn(
            actor_variables, critic_variables, obs, init_rstate, dones, mains)
        if args.collect_steps == args.num_steps:
            next_obs, next_main = next_data
            variables = {
                'params': params["critic"],
                'batch_stats': state_updates['batch_stats']["critic"]}
            next_v = compute_next_value(variables, next_rstate[0], next_obs, next_main)
        else:
            next_v = next_data
        target_values, advantages = compute_advantage(
            new_logits, new_values, next_dones, mains, actions, logits, rewards, next_v)
        loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
            new_logits, new_values, actions, logits, target_values, advantages, mask,
            num_steps=dones.shape[0] // num_envs)
        loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
        approx_kl = jax.lax.stop_gradient(approx_kl)
        return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)

    def single_device_update(
        agent_state: TrainState,
        sharded_storages: List,
        sharded_init_rstate: List,
        sharded_next_data: List,
        key: jax.random.PRNGKey,
    ):
        storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
        next_data, init_rstate = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
            for x in [sharded_next_data, sharded_init_rstate]
        ]
        # reorder storage of individual players
        # main first, opponent second
        num_steps, num_envs = storage.rewards.shape

        if args.segment_length is None:
            loss_grad_fn = jax.value_and_grad(get_advantage_loss, has_aux=True)
        else:
            # TODO: fix it
            loss_grad_fn = jax.value_and_grad(get_loss, has_aux=True)

        def update_epoch(carry, _):
            agent_state, key = carry
            key, subkey = jax.random.split(key)

            def convert_data(x: jnp.ndarray, multi_step=True):
                key = subkey if args.update_epochs > 1 else None
                return reshape_minibatch(
                    x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)

            b_init_rstate, b_next_data = \
                jax.tree.map(partial(convert_data, multi_step=False), (init_rstate, next_data))
            b_storage = jax.tree.map(convert_data, storage)
            b_mask = ~b_storage.dones
            b_rewards = b_storage.rewards

            if args.segment_length is None:
                def update_minibatch(agent_state, minibatch):
                    (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)), grads = \
                        loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch)
                    grads = jax.lax.pmean(grads, axis_name="local_devices")
                    agent_state = agent_state.apply_gradients(grads=grads)
                    agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
                    return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
            else:
                def update_minibatch(carry, minibatch):
                    def update_minibatch_t(carry, minibatch_t):
                        agent_state, init_rstate = carry
                        minibatch_t = init_rstate, *minibatch_t
                        (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, next_rstate)), \
                            grads = loss_grad_fn(
                                agent_state.params, agent_state.batch_stats, *minibatch_t)
                        grads = jax.lax.pmean(grads, axis_name="local_devices")
                        agent_state = agent_state.apply_gradients(grads=grads)
                        agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
                        return (agent_state, next_rstate), (loss, pg_loss, v_loss, ent_loss, approx_kl)

                    init_rstate, *minibatch_t, mask = minibatch
                    target_values, advantages = get_advantage(
                        get_variables(carry), init_rstate, *minibatch_t)
                    minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
                    (carry, _next_rstate), \
                        (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
                            update_minibatch_t, (carry, init_rstate), minibatch_t)
                    return carry, (loss, pg_loss, v_loss, ent_loss, approx_kl)

            agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
                update_minibatch,
                agent_state,
                (
                    b_init_rstate,
                    b_storage.obs,
                    b_storage.dones,
                    b_storage.next_dones,
                    b_storage.mains,
                    b_storage.actions,
                    b_storage.logits,
                    b_rewards,
                    b_mask,
                    b_next_data,
                ),
            )
            return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)

        (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
            update_epoch, (agent_state, key), (), length=args.update_epochs)

        loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
        pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
        v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
        ent_loss = jax.lax.pmean(ent_loss, axis_name="local_devices").mean()
        approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
        return agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, key
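Both update paths thread agent_state through jax.lax.scan, which carries the state across minibatches and stacks the per-minibatch metrics. A toy version of that pattern, with scalar SGD standing in for the real update:

import jax
import jax.numpy as jnp

def update_minibatch(theta, batch):
    grad = 2 * (theta - batch.mean())  # gradient of (theta - mean)**2
    return theta - 0.1 * grad, grad    # (new carry, stacked per-step output)

batches = jnp.arange(12.0).reshape(4, 3)  # 4 minibatches of 3 samples
theta, grads = jax.lax.scan(update_minibatch, jnp.array(0.0), batches)
print(theta, grads.shape)  # final parameter and (4,) collected gradients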
    all_reduce_value = jax.pmap(
        lambda x: jax.lax.pmean(x, axis_name="main_devices"),
        axis_name="main_devices",
        devices=global_main_devices,
    )

    multi_device_update = jax.pmap(
        single_device_update,
        axis_name="local_devices",
        devices=global_learner_decices,
    )

    params_queues = []
    rollout_queues = []

    unreplicated_params = flax.jax_utils.unreplicate(get_actor_variables(agent_state))
    for d_idx, d_id in enumerate(args.actor_device_ids):
        actor_device = local_devices[d_id]
        device_params = jax.device_put(unreplicated_params, actor_device)
        for thread_id in range(args.num_actor_threads):
            params_queues.append(queue.Queue(maxsize=1))
            rollout_queues.append(queue.Queue(maxsize=1))
            if eval_variables:
                params_queues[-1].put(jax.device_put(eval_variables, actor_device))
            actor_thread_id = d_idx * args.num_actor_threads + thread_id
            threading.Thread(
                target=rollout,
                args=(
                    jax.device_put(actor_keys[actor_thread_id], actor_device),
                    args,
                    rollout_queues[-1],
                    params_queues[-1],
                    writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                    actor_device,
                    learner_devices,
                    actor_thread_id,
                ),
            ).start()
            params_queues[-1].put(device_params)

    rollout_queue_get_time = deque(maxlen=10)
    learner_policy_version = 0
    while True:
        learner_policy_version += 1
        rollout_queue_get_time_start = time.time()
        sharded_data_list = []
        eval_stat_list = []
        for d_idx, d_id in enumerate(args.actor_device_ids):
            for thread_id in range(args.num_actor_threads):
                (
                    global_step,
                    update,
                    *sharded_data,
                    avg_params_queue_get_time,
                    eval_stats,
                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                sharded_data_list.append(sharded_data)
                if eval_stats is not None:
                    eval_stat_list.append(eval_stats)

        tb_global_step = args.tb_offset + global_step
        if update % args.eval_interval == 0:
            eval_stats = np.mean(eval_stat_list, axis=0)
            eval_stats = jax.device_put(eval_stats, local_devices[0])
            eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
            eval_time, eval_return, eval_win_rate = eval_stats
            writer.add_scalar(f"charts/eval_return", eval_return, tb_global_step)
            writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, tb_global_step)
            print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")

        rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
        training_time_start = time.time()
        (agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, learner_keys) = multi_device_update(
            agent_state,
            *list(zip(*sharded_data_list)),
            learner_keys,
        )
        unreplicated_params = flax.jax_utils.unreplicate(get_actor_variables(agent_state))
        params_queue_put_time = 0
        for d_idx, d_id in enumerate(args.actor_device_ids):
            device_params = jax.device_put(unreplicated_params, local_devices[d_id])
            device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
            params_queue_put_start = time.time()
            for thread_id in range(args.num_actor_threads):
                params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
            params_queue_put_time += time.time() - params_queue_put_start

        loss = loss[-1].item()
        if np.isnan(loss) or np.isinf(loss):
            raise ValueError(f"loss is {loss}")

        # record rewards for plotting purposes
        if learner_policy_version % args.log_frequency == 0:
            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), tb_global_step)
            writer.add_scalar(
                "stats/rollout_params_queue_get_time_diff",
                np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
                tb_global_step,
            )
            writer.add_scalar("stats/training_time", time.time() - training_time_start, tb_global_step)
            writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), tb_global_step)
            writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), tb_global_step)
            print(
                f"{tb_global_step} actor_update={update}, "
                f"train_time={time.time() - training_time_start:.2f}, "
                f"data_time={rollout_queue_get_time[-1]:.2f}, "
                f"put_time={params_queue_put_time:.2f}"
            )
            writer.add_scalar(
                "charts/learning_rate",
                agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), tb_global_step)
            writer.add_scalar("losses/value_loss", v_loss[-1].item(), tb_global_step)
            writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), tb_global_step)
            writer.add_scalar("losses/entropy", ent_loss[-1].item(), tb_global_step)
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), tb_global_step)
            writer.add_scalar("losses/loss", loss, tb_global_step)

        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
            M_steps = tb_global_step // 2**20
            ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
            ckpt_maneger.save(unreplicated_params, ckpt_name)
            if args.gcs_bucket is not None:
                lastest_path = ckpt_maneger.get_latest()
                copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
                shutil.copyfile(lastest_path, copy_path)
                zip_file_path = "latest.zip"
                zip_files(zip_file_path, [str(copy_path), tb_log_dir])
                sync_to_gcs(args.gcs_bucket, zip_file_path)

        if learner_policy_version >= args.num_updates:
            break

    if args.distributed:
        jax.distributed.shutdown()
    writer.close()


if __name__ == "__main__":
    main()
ygoai/rl/env.py (view file @ 72a1fd28)
@@ -60,11 +60,11 @@ class RecordEpisodeStatistics(gym.Wrapper):
 class CompatEnv(gym.Wrapper):

     def reset(self, **kwargs):
-        observations, infos = super().reset(**kwargs)
+        observations, infos = self.env.reset(**kwargs)
         return observations, infos

     def step(self, action):
-        observations, rewards, terminated, truncated, infos = self.env.step(action)
+        observations, rewards, terminated, truncated, infos = super().step(action)
         dones = np.logical_or(terminated, truncated)
         return (
             observations,
@@ -72,3 +72,28 @@ class CompatEnv(gym.Wrapper):
             dones,
             infos,
         )
+
+
+class EnvPreprocess(gym.Wrapper):
+
+    def __init__(self, env, skip_mask):
+        super().__init__(env)
+        self.skip_mask = skip_mask
+
+    def reset(self, **kwargs):
+        observations, infos = self.env.reset(**kwargs)
+        if self.skip_mask:
+            observations['mask_'] = None
+        return observations, infos
+
+    def step(self, action):
+        observations, rewards, terminated, truncated, infos = super().step(action)
+        if self.skip_mask:
+            observations['mask_'] = None
+        return (
+            observations,
+            rewards,
+            terminated,
+            truncated,
+            infos,
+        )
\ No newline at end of file
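A minimal sketch of the new wrapper's effect, using a stub environment in place of a real ygoenv instance and assuming gym refers to gymnasium as elsewhere in ygoai:

import gymnasium as gym

class StubEnv(gym.Env):
    # returns a dict observation with a 'mask_' entry, like ygoenv
    def reset(self, **kwargs):
        return {'cards_': 0, 'mask_': 1}, {}

    def step(self, action):
        return {'cards_': 0, 'mask_': 1}, 0.0, False, False, {}

env = EnvPreprocess(StubEnv(), skip_mask=True)
obs, info = env.reset()
print(obs['mask_'])  # None: the mask is dropped so the agent skips mask features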
ygoai/rl/jax/agent.py (view file @ 72a1fd28)
@@ -85,7 +85,8 @@ class CardEncoder(nn.Module):
     version: int = 0

     @nn.compact
-    def __call__(self, x_id, x):
+    def __call__(self, x_id, x, mask):
         assert self.version > 0
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
@@ -136,18 +137,35 @@ class CardEncoder(nn.Module):
             x_f = layer_norm()(x_f)
             f_cards = jnp.concatenate([x_id, x_f], axis=-1)
             f_cards = f_cards + f_loc + f_seq
+            f_cards_g = None
         else:
             x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
             x_id = jax.nn.swish(x_id)
             f_loc = embed(9, c // 16 * 2)(x_loc)
             f_seq = embed(76, c // 16 * 2)(x_seq)
-            x_cards = jnp.concatenate([
-                f_loc, f_seq, x_owner, x_position, x_overley, x_attribute, x_race,
-                x_level, x_counter, x_negated, x_atk, x_def, x_type], axis=-1)
+            feats_g = [
+                x_id, f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
+                x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
+            if mask is not None:
+                assert len(feats_g) == mask.shape[-1]
+                feats = [
+                    jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
+                    for i, f in enumerate(feats_g)
+                ]
+            else:
+                feats = feats_g
+            x_cards = jnp.concatenate(feats[1:], axis=-1)
             x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
-            x_cards = x_cards * x_id
+            x_cards = x_cards * feats[0]
             f_cards = layer_norm()(x_cards)
-        return f_cards, c_mask
+        if self.oppo_info:
+            x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
+            x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
+            x_cards_g = x_cards_g * feats_g[0]
+            f_cards_g = layer_norm()(x_cards_g)
+        else:
+            f_cards_g = None
+        return f_cards_g, f_cards, c_mask


 class GlobalEncoder(nn.Module):
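The masking scheme above replaces each masked feature with the same feature of the last card (a padding slot), so hidden information is blanked per feature column. A toy illustration of one column:

import jax.numpy as jnp

n_cards, dim = 3, 2
f = jnp.arange(n_cards * dim, dtype=jnp.float32).reshape(n_cards, dim)
mask = jnp.array([[1], [0], [1]])  # one column of the full mask, per card

# the masked-out card (row 1) takes the last card's features
out = jnp.where(mask == 1, f, f[-1:, :])
print(out)  # [[0, 1], [4, 5], [4, 5]]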
@@ -229,35 +247,26 @@ class Encoder(nn.Module):
         id_embed = embed(n_embed, embed_dim)
         card_encoder = CardEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)
+            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
+            version=self.version, oppo_info=self.oppo_info)
         ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
         action_encoder = ActionEncoderCls(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)

-        x_cards_g = x['g_cards_'] if self.oppo_info else None
         x_cards = x['cards_']
         x_global = x['global_']
         x_actions = x['actions_']
         x_h_actions = x['h_actions_']
+        mask = x['mask_']
         batch_size = x_global.shape[0]

         valid = x_global[:, -1] == 0

-        n_cards = x_cards.shape[-2]
-        if self.oppo_info:
-            x_cards = jnp.concatenate([x_cards, x_cards_g], axis=-2)
         x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
         x_id = id_embed(x_id)
         if self.freeze_id:
             x_id = jax.lax.stop_gradient(x_id)
-        f_cards, c_mask = card_encoder(x_id, x_cards[:, :, 2:])
-        if self.oppo_info:
-            f_cards_me, f_cards_g = jnp.split(f_cards, [n_cards], axis=-2)
-        else:
-            f_cards_me, f_cards_g = f_cards, None
+        f_cards_g, f_cards_me, c_mask = card_encoder(x_id, x_cards[:, :, 2:], mask)

         # Cards
         fs_g_card = []
@@ -528,17 +537,16 @@ class GlobalCritic(nn.Module):
     param_dtype: jnp.dtype = jnp.float32

     @nn.compact
-    def __call__(self, rstate1, rstate2, g_cards):
-        f_state = jnp.concatenate([rstate1[0], rstate1[1], rstate2[0], rstate2[0]], axis=-1)
+    def __call__(self, f_state_r1, f_state_r2, f_state, g_cards):
+        f_state = jnp.concatenate([f_state_r1, f_state_r2, f_state, g_cards], axis=-1)
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         x = mlp(self.channels, last_lin=True)(f_state)
-        c = self.channels[-1]
-        t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
-        s, b = jnp.split(t, 2, axis=-1)
-        x = x * s + b
-        x = mlp([c], last_lin=False)(x)
+        # c = self.channels[-1]
+        # t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
+        # s, b = jnp.split(t, 2, axis=-1)
+        # x = x * s + b
+        # x = mlp([c], last_lin=False)(x)
         x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
                      kernel_init=nn.initializers.orthogonal(1.0))(x)
         return x
@@ -720,9 +728,11 @@ class RNNAgent(nn.Module):
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
         logits = actor(f_state_r, f_actions, mask)

-        CriticCls = CrossCritic if self.batch_norm else Critic
-        cs = [self.critic_width] * self.critic_depth
-        critic = CriticCls(channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+        if self.oppo_info:
+            critic = GlobalCritic(
+                channels=[c, c], dtype=self.dtype, param_dtype=self.param_dtype)
             if not multi_step:
                 if isinstance(rstate[0], tuple):
                     rstate1_t, rstate2_t = rstate
@@ -735,12 +745,9 @@ class RNNAgent(nn.Module):
                     lambda x1, x2: jnp.where(main, x1, x2), rstate1, rstate2)
                 rstate2_t = jax.tree.map(
                     lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
-            value = critic(rstate1_t, rstate2_t, f_g)
+            f_critic = jnp.concatenate(
+                [rstate1_t[1], rstate2_t[1], f_state, f_g], axis=-1)
+            value = critic(f_critic, train)
+        else:
+            CriticCls = CrossCritic if self.batch_norm else Critic
+            cs = [self.critic_width] * self.critic_depth
+            critic = CriticCls(channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+            value = critic(f_state_r, train)

         if self.int_head:
ygoai/rl/jax/agent2.py (new file, 0 → 100644, view file @ 72a1fd28)
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention


default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)


def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
    if noam:
        return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
    else:
        return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)


class ActionEncoderV1(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        c = self.channels
        div = 8
        embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)
        x_a_msg = embed(30, c // div)(x[:, :, 0])
        x_a_act = embed(10, c // div)(x[:, :, 1])
        x_a_finish = embed(3, c // div // 2)(x[:, :, 2])
        x_a_effect = embed(256, c // div * 2)(x[:, :, 3])
        x_a_phase = embed(4, c // div // 2)(x[:, :, 4])
        x_a_position = embed(9, c // div)(x[:, :, 5])
        x_a_number = embed(13, c // div // 2)(x[:, :, 6])
        x_a_place = embed(31, c // div)(x[:, :, 7])
        x_a_attrib = embed(10, c // div // 2)(x[:, :, 8])
        xs = [x_a_msg, x_a_act, x_a_finish, x_a_effect, x_a_phase,
              x_a_position, x_a_number, x_a_place, x_a_attrib]
        return xs


class CardEncoder(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    version: int = 0

    @nn.compact
    def __call__(self, x_id, x, mask):
        assert self.version > 0
        c = self.channels
        mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)
        fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)

        num_fc = mlp((c // 8,), last_lin=False)
        bin_points, bin_intervals = make_bin_params(n_bins=32)
        num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))

        x1 = x[:, :, :10].astype(jnp.int32)
        x2 = x[:, :, 10:].astype(jnp.float32)

        x_loc = x1[:, :, 0]
        x_seq = x1[:, :, 1]
        c_mask = x_loc == 0
        c_mask = c_mask.at[:, 0].set(False)

        x_owner = embed(2, c // 16)(x1[:, :, 2])
        x_position = embed(9, c // 16)(x1[:, :, 3])
        x_overley = embed(2, c // 16)(x1[:, :, 4])
        x_attribute = embed(8, c // 16)(x1[:, :, 5])
        x_race = embed(27, c // 16)(x1[:, :, 6])
        x_level = embed(14, c // 16)(x1[:, :, 7])
        x_counter = embed(16, c // 16)(x1[:, :, 8])
        x_negated = embed(3, c // 16)(x1[:, :, 9])

        x_atk = num_transform(x2[:, :, 0:2])
        x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
        x_def = num_transform(x2[:, :, 2:4])
        x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
        x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])

        x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
        x_id = jax.nn.swish(x_id)
        f_loc = embed(9, c // 16 * 2)(x_loc)
        f_seq = embed(76, c // 16 * 2)(x_seq)
        feats = [
            x_id, f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
            x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
        if mask is not None:
            assert len(feats) == mask.shape[-1]
            feats = [
                jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
                for i, f in enumerate(feats)
            ]
        x_cards = jnp.concatenate(feats[1:], axis=-1)
        x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
        x_cards = x_cards * feats[0]
        f_cards = layer_norm()(x_cards)
        return f_cards, c_mask
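# Editorial note: the `mask` argument gates the 14 per-card feature slots above,
# one mask column per entry of `feats` (id, loc, seq, owner, position, overlay,
# attribute, race, level, counter, negated, atk, def, type). Where a column is
# 0, the feature is replaced by that of the last (padding) card slot, so
# information the acting player should not see never reaches the network.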
class GlobalEncoder(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    version: int = 0

    @nn.compact
    def __call__(self, x):
        batch_size = x.shape[0]
        c = self.channels
        mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(
            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)
        fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)

        count_embed = embed(100, c // 16)
        hand_count_embed = embed(100, c // 16)

        num_fc = mlp((c // 8,), last_lin=False)
        bin_points, bin_intervals = make_bin_params(n_bins=32)
        num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))

        x1 = x[:, :4].astype(jnp.float32)
        x2 = x[:, 4:8].astype(jnp.int32)
        x3 = x[:, 8:22].astype(jnp.int32)

        x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 0:2]))
        x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 2:4]))

        x_turn = embed(20, c // 8)(x2[:, 0])
        x_phase = embed(11, c // 8)(x2[:, 1])
        x_if_first = embed(2, c // 8)(x2[:, 2])
        x_is_my_turn = embed(2, c // 8)(x2[:, 3])

        x_cs = count_embed(x3).reshape((batch_size, -1))
        x_my_hand_c = hand_count_embed(x3[:, 1])
        x_op_hand_c = hand_count_embed(x3[:, 8])

        x = jnp.concatenate([
            x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
            x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
        x = layer_norm()(x)
        return x
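# Editorial note: the global vector is consumed in three slices: x[:, :4] holds
# both players' LP as 2-byte values (binned by bytes_to_bin), x[:, 4:8] holds
# the turn, phase, going-first and is-my-turn indices, and x[:, 8:22] holds 14
# per-location card counts, with the two hand counts (indices 1 and 8) also
# receiving a dedicated embedding.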
class Encoder(nn.Module):
    channels: int = 128
    out_channels: Optional[int] = None
    num_layers: int = 2
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    freeze_id: bool = False
    use_history: bool = True
    card_mask: bool = False
    noam: bool = False
    action_feats: bool = True
    info_mask: bool = False
    version: int = 0

    @nn.compact
    def __call__(self, x):
        c = self.channels
        if self.embedding_shape is None:
            n_embed, embed_dim = 999, 1024
        elif isinstance(self.embedding_shape, int):
            n_embed, embed_dim = self.embedding_shape, 1024
        else:
            n_embed, embed_dim = self.embedding_shape
        n_embed = 1 + n_embed  # 1 (index 0) for unknown

        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
        embed = partial(
            nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype,
            embedding_init=default_embed_init)
        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

        id_embed = embed(n_embed, embed_dim)
        card_encoder = CardEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
            version=self.version)
        ActionEncoderCls = ActionEncoderV1
        action_encoder = ActionEncoderCls(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)

        x_cards = x['cards_']
        x_global = x['global_']
        x_actions = x['actions_']
        x_h_actions = x['h_actions_']
        mask = x['mask_'] if self.info_mask else None
        batch_size = x_global.shape[0]

        valid = x_global[:, -1] == 0

        x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
        x_id = id_embed(x_id)
        if self.freeze_id:
            x_id = jax.lax.stop_gradient(x_id)
        f_cards, c_mask = card_encoder(x_id, x_cards[:, :, 2:], mask)

        # Cards
        g_card_embed = self.param(
            "g_card_embed",
            lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
            (1, c), self.param_dtype)
        f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
        f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
        if self.card_mask:
            c_mask = jnp.concatenate(
                [jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
        else:
            c_mask = None

        num_heads = max(2, c // 128)
        for _ in range(self.num_layers):
            f_cards = get_encoder_layer_cls(
                self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                f_cards, src_key_padding_mask=c_mask)
        f_cards = layer_norm(dtype=self.dtype)(f_cards)
        f_g_card = f_cards[:, 0]

        # Global
        x_global = GlobalEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
            version=self.version)(x_global)
        x_global = x_global.astype(self.dtype)
        if self.version == 2:
            x_global = fc_layer(c, dtype=jnp.float32)(x_global)
            f_global = x_global + GLUMlp(
                c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(
                layer_norm(dtype=self.dtype)(x_global))
        else:
            f_global = x_global + MLP(
                (c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
            f_global = fc_layer(c, dtype=self.dtype)(f_global)
        f_global = layer_norm(dtype=self.dtype)(f_global)

        # History actions
        x_h_actions = x_h_actions.astype(jnp.int32)
        h_mask = x_h_actions[:, :, 3] == 0  # msg == 0
        h_mask = h_mask.at[:, 0].set(False)

        x_h_id = decode_id(x_h_actions[..., 1:3])
        x_h_id = id_embed(x_h_id)
        if self.freeze_id:
            x_h_id = jax.lax.stop_gradient(x_h_id)
        x_h_id = fc_layer(c, dtype=jnp.float32)(x_h_id)

        x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
        x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
        x_h_a_phase = embed(12, c // 2)(x_h_actions[:, :, 13])
        x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
        x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
        x_h_a_feats = layer_norm()(x_h_a_feats)
        x_h_a_feats = fc_layer(c, dtype=self.dtype)(x_h_a_feats)
        if self.noam:
            f_h_actions = LlamaEncoderLayer(
                num_heads, dtype=self.dtype, param_dtype=self.param_dtype,
                rope=True, n_positions=64)(x_h_a_feats, src_key_padding_mask=h_mask)
        else:
            x_h_a_feats = PositionalEncoding()(x_h_a_feats)
            f_h_actions = EncoderLayer(
                num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                x_h_a_feats, src_key_padding_mask=h_mask)
        f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])

        # Actions
        x_actions = x_actions.astype(jnp.int32)

        na_card_embed = self.param(
            'na_card_embed',
            lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
            (1, c), self.param_dtype)
        f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
        f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)

        spec_index = x_actions[..., 0]
        B = jnp.arange(batch_size)
        f_a_cards = f_cards[B[:, None], spec_index]

        x_a_id = decode_id(x_actions[..., 1:3])
        x_a_id = id_embed(x_a_id)
        if self.freeze_id:
            x_a_id = jax.lax.stop_gradient(x_a_id)
        x_a_id = fc_layer(c, dtype=jnp.float32)(x_a_id)

        x_a_feats = action_encoder(x_actions[..., 3:])
        x_a_feats.append(x_a_id)
        x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
        x_a_feats = layer_norm()(x_a_feats)
        x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
        f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
        f_actions = jax.nn.silu(f_a_cards) * x_a_feats
        f_actions = fc_layer(c, dtype=self.dtype)(f_actions)
        f_actions = x_a_feats + f_actions

        a_mask = x_actions[:, :, 3] == 0
        a_mask = a_mask.at[:, 0].set(False)

        g_feats = [f_g_card, f_global]
        if self.use_history:
            g_feats.append(f_g_h_actions)

        if self.action_feats:
            f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
            a_mask_ = (1 - a_mask.astype(f_actions.dtype))
            f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
            f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
            g_feats.append(f_g_actions)

        f_state = jnp.concatenate(g_feats, axis=-1)
        oc = self.out_channels or c
        if self.version == 2:
            f_state = GLUMlp(
                intermediate_size=c * 2, output_size=oc,
                dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
        else:
            f_state = MLP(
                (c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
        f_state = layer_norm(dtype=self.dtype)(f_state)
        return f_actions, f_state, a_mask, valid
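# Editorial note: the encoder returns per-option features `f_actions`, a pooled
# state vector `f_state` (global card token, global features, and the optional
# history/action summaries), the action padding mask `a_mask`, and `valid`,
# derived from the last byte of the global block.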
class Actor(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        f_state = f_state.astype(self.dtype)
        f_actions = f_actions.astype(self.dtype)
        c = self.channels
        mlp = partial(
            MLP, dtype=jnp.float32, param_dtype=self.param_dtype,
            last_kernel_init=nn.initializers.orthogonal(0.01))
        f_state = mlp((c,), use_bias=True)(f_state)
        logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
        big_neg = jnp.finfo(logits.dtype).min
        logits = jnp.where(mask, big_neg, logits)
        return logits


class FiLMActor(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    noam: bool = False

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        f_state = f_state.astype(self.dtype)
        f_actions = f_actions.astype(self.dtype)
        c = self.channels
        t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
        a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)

        num_heads = max(2, c // 128)
        f_actions = get_encoder_layer_cls(
            self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
            f_actions, mask, a_s, a_b, o_s, o_b)

        logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
                          kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
        big_neg = jnp.finfo(logits.dtype).min
        logits = jnp.where(mask, big_neg, logits)
        return logits


class Critic(nn.Module):
    channels: Sequence[int] = (128, 128, 128)
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, train):
        f_state = f_state.astype(self.dtype)
        mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        x = mlp(self.channels, last_lin=False)(f_state)
        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
                     kernel_init=nn.initializers.orthogonal(1.0))(x)
        return x


class CrossCritic(nn.Module):
    channels: Sequence[int] = (128, 128, 128)
    # dropout_rate: Optional[float] = None
    batch_norm_momentum: float = 0.99
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, train):
        x = f_state.astype(self.dtype)
        linear = partial(nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
        BN = partial(
            BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
            momentum=self.batch_norm_momentum, axis_name="local_devices",
            use_running_average=not train)
        x = BN()(x)
        for c in self.channels:
            x = linear(c)(x)
            # if self.use_layer_norm:
            #     x = nn.LayerNorm()(x)
            x = nn.relu(x)
            # x = nn.leaky_relu(x, negative_slope=0.1)
            x = BN()(x)
        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
        return x


class GlobalCritic(nn.Module):
    channels: Sequence[int] = (128, 128)
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state_r1, f_state_r2, f_state, g_cards):
        f_state = jnp.concatenate(
            [f_state_r1, f_state_r2, f_state, g_cards], axis=-1)
        mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
        x = mlp(self.channels, last_lin=True)(f_state)
        # c = self.channels[-1]
        # t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
        # s, b = jnp.split(t, 2, axis=-1)
        # x = x * s + b
        # x = mlp([c], last_lin=False)(x)
        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
                     kernel_init=nn.initializers.orthogonal(1.0))(x)
        return x


def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False):
    if main is not None:
        rstate1, rstate2 = rstate
        rstate = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
    rstate, f_state = rnn_layer(rstate, f_state)
    if main is not None:
        rstate1 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
        rstate2 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
        rstate = rstate1, rstate2
    if done is not None:
        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
    if return_state:
        return rstate, (f_state, rstate)
    else:
        return rstate, f_state
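# Editorial note: when `main` is given, `rstate` holds both players' recurrent
# states; the acting player's state is selected, advanced one step, written
# back to its owner, and zeroed wherever `done` marks an episode boundary.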
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main,
                   switch=True, return_state=False):
    if switch:
        def body_fn(cell, carry, x, done, switch):
            rstate, init_rstate2 = carry
            rstate, y = cell(rstate, x)
            rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
            rstate = jax.tree.map(
                lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
            return (rstate, init_rstate2), y
    else:
        def body_fn(cell, carry, x, done, main):
            return rnn_step_by_main(cell, carry, x, done, main, return_state)
    scan = nn.scan(
        body_fn, variable_broadcast='params', split_rngs={'params': False})
    rstate, f_state = scan(rnn_layer, rstate, f_state, done, switch_or_main)
    return rstate, f_state
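# Editorial note: in switch mode the scan carries (current state, the second
# player's initial state) and restores the latter at steps where `switch` is
# set, so each player's trajectory is unrolled with its own recurrent state.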
def rnn_step(rnn_layer, rstate, f_state, done):
    rstate, f_state = rnn_layer(rstate, f_state)
    if done is not None:
        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
    return rstate, f_state


def rnn_forward(rnn_layer, rstate, f_state, done):
    def body_fn(cell, carry, x, done):
        return rnn_step(cell, carry, x, done)
    scan = nn.scan(
        body_fn, variable_broadcast='params', split_rngs={'params': False})
    rstate, f_state = scan(rnn_layer, rstate, f_state, done)
    return rstate, f_state


@dataclass
class EncoderArgs:
    num_layers: int = 1
    """the number of layers for the agent"""
    num_channels: int = 64
    """the number of channels for the agent"""
    use_history: bool = True
    """whether to use history actions as input for agent"""
    card_mask: bool = False
    """whether to mask the padding card as ignored in the transformer"""
    noam: bool = False
    """whether to use Noam architecture for the transformer layer"""
    action_feats: bool = True
    """whether to use action features for the global state"""
    version: int = 0
    """the version of the environment and the agent"""


@dataclass
class ModelArgs(EncoderArgs):
    rnn_channels: int = 256
    """the number of channels for the RNN in the agent"""
    rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
    """the type of RNN to use, None for no RNN"""
    film: bool = False
    """whether to use FiLM for the actor"""
    rnn_shortcut: bool = False
    """whether to use shortcut for the RNN"""
    batch_norm: bool = False
    """whether to use batch normalization for the critic"""
    critic_width: int = 256
    """the width of the critic"""
    critic_depth: int = 3
    """the depth of the critic"""
    rwkv_head_size: int = 32
    """the head size for the RWKV"""


class RNNAgent(nn.Module):
    num_layers: int = 1
    num_channels: int = 64
    rnn_channels: int = 256
    use_history: bool = True
    card_mask: bool = False
    rnn_type: str = 'lstm'
    film: bool = False
    noam: bool = False
    rwkv_head_size: int = 32
    action_feats: bool = True
    rnn_shortcut: bool = False
    batch_norm: bool = False
    critic_width: int = 256
    critic_depth: int = 3
    version: int = 0
    eval: bool = False
    actor: bool = True
    freeze_id: bool = False
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
        batch_size = jax.tree.leaves(rstate)[0].shape[0]

        c = self.num_channels
        oc = self.rnn_channels if self.rnn_type == 'rwkv' else None
        encoder = Encoder(
            channels=c,
            out_channels=oc,
            num_layers=self.num_layers,
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            freeze_id=self.freeze_id,
            use_history=self.use_history,
            card_mask=self.card_mask,
            noam=self.noam,
            action_feats=self.action_feats,
            info_mask=False if self.eval else self.actor,
            version=self.version,
        )

        f_actions, f_state, mask, valid = encoder(x)

        if self.rnn_type in ['lstm']:
            rnn_layer = nn.OptimizedLSTMCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
                kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type == 'gru':
            rnn_layer = nn.GRUCell(
                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
                kernel_init=nn.initializers.orthogonal(1.0))
        elif self.rnn_type == 'rwkv':
            num_heads = self.rnn_channels // self.rwkv_head_size
            rnn_layer = Rwkv6SelfAttention(
                num_heads, dtype=self.dtype, param_dtype=self.param_dtype)
        elif self.rnn_type is None:
            rnn_layer = None

        if rnn_layer is None:
            f_state_r = f_state
        else:
            num_steps = f_state.shape[0] // batch_size
            multi_step = num_steps > 1
            if done is None:
                assert not multi_step
            if multi_step:
                f_state_r, done, switch_or_main = jax.tree.map(
                    lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]),
                    (f_state, done, switch_or_main))
                if self.actor:
                    rstate, f_state_r = rnn_forward_2p(
                        rnn_layer, rstate, f_state_r, done, switch_or_main, False)
                else:
                    rstate, f_state_r = rnn_forward(
                        rnn_layer, rstate, f_state_r, done)
                f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
            else:
                if self.actor:
                    rstate, f_state_r = rnn_step_by_main(
                        rnn_layer, rstate, f_state, done, switch_or_main)
                else:
                    rstate, f_state_r = rnn_step(
                        rnn_layer, rstate, f_state, done)

        if self.rnn_shortcut:
            f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)

        if self.actor:
            if self.film:
                actor = FiLMActor(
                    channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
            else:
                actor = Actor(
                    channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
            logits = actor(f_state_r, f_actions, mask)
            return rstate, logits
        else:
            CriticCls = CrossCritic if self.batch_norm else Critic
            cs = [self.critic_width] * self.critic_depth
            critic = CriticCls(channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
            value = critic(f_state_r, train)
            return rstate, value

    def init_rnn_state(self, batch_size):
        if self.rnn_type in ['lstm']:
            return (
                np.zeros((batch_size, self.rnn_channels)),
                np.zeros((batch_size, self.rnn_channels)),
            )
        elif self.rnn_type == 'gru':
            return np.zeros((batch_size, self.rnn_channels))
        elif self.rnn_type == 'rwkv':
            head_size = self.rwkv_head_size
            num_heads = self.rnn_channels // self.rwkv_head_size
            return (
                np.zeros((batch_size, num_heads * head_size)),
                np.zeros((batch_size, num_heads * head_size * head_size)),
            )
        else:
            return None
\ No newline at end of file
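A minimal usage sketch for the module above (assuming it is importable as ygoai.rl.jax.agent2; the observation dict expected by `__call__` is defined by the environment and omitted here):

    from ygoai.rl.jax.agent2 import RNNAgent

    agent = RNNAgent(num_channels=128, rnn_channels=256, rnn_type='lstm')
    rstate = agent.init_rnn_state(batch_size=4)
    # For the LSTM case this is an (h, c) pair of zeros, each of shape (4, 256).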
ygoai/rl/jax/utils.py
View file @ 72a1fd28
...
...
@@ -10,8 +10,6 @@ import optax
import numpy as np

from ygoai.rl.env import RecordEpisodeStatistics


def masked_mean(x, valid):
    x = jnp.where(valid, x, jnp.zeros_like(x))
...
...
ygoai/rl/utils.py
View file @ 72a1fd28
import re
import numpy as np
import gymnasium as gym
import pickle
import optree
import torch

from ygoai.rl.env import RecordEpisodeStatistics
from ygoai.rl.env import RecordEpisodeStatistics, EnvPreprocess


def split_param_groups(model, regex):
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @ 72a1fd28
...
...
@@ -1540,7 +1540,7 @@ public:
                     Spec<uint8_t>({conf["max_options"_], n_action_feats})),
                 "obs:h_actions_"_.Bind(
                     Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
                 "obs:g_cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
                 "obs:mask_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 14})),
                 "info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
                 "info:to_play"_.Bind(Spec<int>({}, {0, 1})),
                 "info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
...
...
@@ -2337,9 +2337,18 @@ public:
      return;
    }
    auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
    SpecInfos spec_infos;
    std::vector<int> loc_n_cards;
    if (spec_.config["oppo_info"_]) {
      _set_obs_g_cards(state["obs:g_cards_"_]);
      _set_obs_g_cards(state["obs:cards_"_], to_play_);
      auto [spec_infos_, loc_n_cards_] = _set_obs_mask(state["obs:mask_"_], to_play_);
      spec_infos = spec_infos_;
      loc_n_cards = loc_n_cards_;
    } else {
      auto [spec_infos_, loc_n_cards_] =
          _set_obs_cards(state["obs:cards_"_], to_play_);
      spec_infos = spec_infos_;
      loc_n_cards = loc_n_cards_;
    }
    _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
...
...
@@ -2448,27 +2457,85 @@ private:
    return {spec_infos, loc_n_cards};
  }

  void _set_obs_g_cards(TArray<uint8_t> &f_cards) {
  void _set_obs_g_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
    int offset = 0;
    for (auto pi = 0; pi < 2; pi++) {
      const PlayerId player = (to_play + pi) % 2;
      std::vector<uint8_t> configs = {
          LOCATION_DECK,  LOCATION_HAND,    LOCATION_MZONE, LOCATION_SZONE,
          LOCATION_GRAVE, LOCATION_REMOVED, LOCATION_EXTRA,
      };
      for (auto location : configs) {
        std::vector<Card> cards = get_cards_in_location(pi, location);
        std::vector<Card> cards = get_cards_in_location(player, location);
        int n_cards = cards.size();
        for (int i = 0; i < n_cards; ++i) {
          const auto &c = cards[i];
          CardId card_id = c_get_card_id(c.code_);
          _set_obs_card_(f_cards, offset, c, false, card_id, false);
          offset++;
          if (offset == (spec_.config["max_cards"_] * 2 - 1)) {
            return;
          }
        }
      }
    }
  }
  std::tuple<SpecInfos, std::vector<int>>
  _set_obs_mask(TArray<uint8_t> &mask, PlayerId to_play) {
    SpecInfos spec_infos;
    std::vector<int> loc_n_cards;
    int offset = 0;
    for (auto pi = 0; pi < 2; pi++) {
      const PlayerId player = (to_play + pi) % 2;
      const bool opponent = pi == 1;
      std::vector<std::pair<uint8_t, bool>> configs = {
          {LOCATION_DECK, true},   {LOCATION_HAND, true},
          {LOCATION_MZONE, false}, {LOCATION_SZONE, false},
          {LOCATION_GRAVE, false}, {LOCATION_REMOVED, false},
          {LOCATION_EXTRA, true},
      };
      for (auto &[location, hidden_for_opponent] : configs) {
        // check this
        if (opponent && (revealed_.size() != 0)) {
          hidden_for_opponent = false;
        }
        if (opponent && hidden_for_opponent) {
          auto n_cards = YGO_QueryFieldCount(pduel_, player, location);
          loc_n_cards.push_back(n_cards);
          for (auto i = 0; i < n_cards; i++) {
            mask(offset, 1) = 1;
            mask(offset, 3) = 1;
            offset++;
          }
        } else {
          std::vector<Card> cards = get_cards_in_location(player, location);
          int n_cards = cards.size();
          loc_n_cards.push_back(n_cards);
          for (int i = 0; i < n_cards; ++i) {
            const auto &c = cards[i];
            auto spec = c.get_spec(opponent);
            bool hide = false;
            if (opponent) {
              hide = c.position_ & POS_FACEDOWN;
              if (revealed_.find(spec) != revealed_.end()) {
                hide = false;
              }
            }
            CardId card_id = 0;
            if (!hide) {
              card_id = c_get_card_id(c.code_);
            }
            _set_obs_mask_(mask, offset, c, hide);
            offset++;
            spec_infos[spec] = {static_cast<uint16_t>(offset), card_id};
          }
        }
      }
    }
    return {spec_infos, loc_n_cards};
  }
  void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
                      bool hide, CardId card_id = 0, bool global = false) {
...
...
@@ -2531,6 +2598,54 @@ private:
    }
  }

  void _set_obs_mask_(TArray<uint8_t> &mask, int offset, const Card &c,
                      bool hide, CardId card_id = 0, bool global = false) {
    // check offset exceeds max_cards
    uint8_t location = c.location_;
    bool overlay = location & LOCATION_OVERLAY;
    if (overlay) {
      location = location & 0x7f;
    }
    if (overlay) {
      hide = false;
    }
    if (!hide) {
      if (card_id != 0) {
        mask(offset, 0) = 1;
      }
    }
    mask(offset, 1) = 1;
    if (location == LOCATION_MZONE || location == LOCATION_SZONE ||
        location == LOCATION_GRAVE) {
      mask(offset, 2) = 1;
    }
    mask(offset, 3) = 1;
    if (overlay) {
      mask(offset, 4) = 1;
      mask(offset, 5) = 1;
    } else {
      if (location == LOCATION_DECK || location == LOCATION_HAND ||
          location == LOCATION_EXTRA) {
        if (hide || (c.position_ & POS_FACEDOWN)) {
          mask(offset, 4) = 1;
        }
      } else {
        mask(offset, 4) = 1;
      }
    }
    if (!hide) {
      mask(offset, 6) = 1;
      mask(offset, 7) = 1;
      mask(offset, 8) = 1;
      mask(offset, 9) = 1;
      mask(offset, 10) = 1;
      mask(offset, 11) = 1;
      mask(offset, 12) = 1;
      mask(offset, 13) = 1;
    }
  }
  void _set_obs_global(TArray<uint8_t> &feat, PlayerId player,
                       const std::vector<int> &loc_n_cards) {
    uint8_t me = player;
    uint8_t op = 1 - player;
...
...
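The 14 mask columns written by `_set_obs_mask_` line up one-to-one with the 14 gated card feature slots in `CardEncoder` (hence its `assert len(feats) == mask.shape[-1]`). A sketch of the correspondence inferred from the two listings (editorial, not part of the commit):

    MASK_COLUMNS = [
        'id',        # 0: set only when the card id is visible
        'location',  # 1: always set
        'sequence',  # 2: set for mzone / szone / grave
        'owner',     # 3: always set
        'position',  # 4: set for public zones, or hidden face-down deck/hand/extra cards
        'overlay',   # 5: set for overlay (material) cards
        'attribute', 'race', 'level', 'counter',   # 6-9: set when not hidden
        'negated', 'atk', 'def', 'type',           # 10-13: set when not hidden
    ]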