Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
72a1fd28
Commit
72a1fd28
authored
Jun 19, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add oppo_info
parent
5cd9807d
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2068 additions
and
53 deletions
+2068
-53
scripts/cleanba.py
scripts/cleanba.py
+5
-5
scripts/cleanba_g.py
scripts/cleanba_g.py
+1213
-0
ygoai/rl/env.py
ygoai/rl/env.py
+27
-2
ygoai/rl/jax/agent.py
ygoai/rl/jax/agent.py
+42
-35
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+660
-0
ygoai/rl/jax/utils.py
ygoai/rl/jax/utils.py
+0
-2
ygoai/rl/utils.py
ygoai/rl/utils.py
+1
-4
ygoenv/ygoenv/ygopro/ygopro.h
ygoenv/ygoenv/ygopro/ygopro.h
+120
-5
No files found.
scripts/cleanba.py
View file @
72a1fd28
...
...
@@ -23,9 +23,10 @@ from rich.pretty import pprint
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.utils
import
RecordEpisodeStatistics
,
EnvPreprocess
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.jax.agent
import
RNNAgent
,
ModelArgs
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
,
TrainState
from
ygoai.rl.jax.utils
import
masked_normalize
,
categorical_sample
,
TrainState
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax.switch
import
truncated_gae_sep
as
gae_sep_switch
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
\
...
...
@@ -356,6 +357,7 @@ def rollout(
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
)
envs
=
EnvPreprocess
(
envs
,
skip_mask
=
not
args
.
m1
.
oppo_info
)
envs
=
RecordEpisodeStatistics
(
envs
)
eval_envs
=
make_env
(
...
...
@@ -363,6 +365,7 @@ def rollout(
local_seed
+
100000
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
EnvPreprocess
(
eval_envs
,
skip_mask
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
...
...
@@ -440,9 +443,6 @@ def rollout(
init_rstates
=
[]
# @jax.jit
# def prepare_data(storage: List[Transition]) -> Transition:
# return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
@
jax
.
jit
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
stack
(
xs
),
*
storage
)
...
...
@@ -566,7 +566,7 @@ def rollout(
for
x
in
partitioned_storage
:
if
isinstance
(
x
,
dict
):
x
=
{
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
if
v
is
not
None
else
None
for
k
,
v
in
x
.
items
()
}
elif
x
is
not
None
:
...
...
scripts/cleanba_g.py
0 → 100644
View file @
72a1fd28
This diff is collapsed.
Click to expand it.
ygoai/rl/env.py
View file @
72a1fd28
...
...
@@ -60,11 +60,11 @@ class RecordEpisodeStatistics(gym.Wrapper):
class
CompatEnv
(
gym
.
Wrapper
):
def
reset
(
self
,
**
kwargs
):
observations
,
infos
=
s
uper
()
.
reset
(
**
kwargs
)
observations
,
infos
=
s
elf
.
env
.
reset
(
**
kwargs
)
return
observations
,
infos
def
step
(
self
,
action
):
observations
,
rewards
,
terminated
,
truncated
,
infos
=
s
elf
.
env
.
step
(
action
)
observations
,
rewards
,
terminated
,
truncated
,
infos
=
s
uper
()
.
step
(
action
)
dones
=
np
.
logical_or
(
terminated
,
truncated
)
return
(
observations
,
...
...
@@ -72,3 +72,28 @@ class CompatEnv(gym.Wrapper):
dones
,
infos
,
)
class EnvPreprocess(gym.Wrapper):
    """Environment wrapper that can blank out the ``'mask_'`` observation entry.

    When ``skip_mask`` is truthy, the ``'mask_'`` key of every observation
    returned by ``reset`` and ``step`` is overwritten with ``None``; otherwise
    observations are passed through untouched.
    """

    def __init__(self, env, skip_mask):
        super().__init__(env)
        # Whether to replace observations['mask_'] with None.
        self.skip_mask = skip_mask

    def _strip_mask(self, obs):
        # Mutates the observation mapping in place (same object is returned).
        if self.skip_mask:
            obs['mask_'] = None
        return obs

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return self._strip_mask(obs), info

    def step(self, action):
        obs, reward, term, trunc, info = super().step(action)
        obs = self._strip_mask(obs)
        return (obs, reward, term, trunc, info)
\ No newline at end of file
ygoai/rl/jax/agent.py
View file @
72a1fd28
...
...
@@ -85,7 +85,8 @@ class CardEncoder(nn.Module):
version
:
int
=
0
@
nn
.
compact
def
__call__
(
self
,
x_id
,
x
):
def
__call__
(
self
,
x_id
,
x
,
mask
):
assert
self
.
version
>
0
c
=
self
.
channels
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
layer_norm
=
partial
(
nn
.
LayerNorm
,
use_scale
=
True
,
use_bias
=
True
)
...
...
@@ -136,18 +137,35 @@ class CardEncoder(nn.Module):
x_f
=
layer_norm
()(
x_f
)
f_cards
=
jnp
.
concatenate
([
x_id
,
x_f
],
axis
=-
1
)
f_cards
=
f_cards
+
f_loc
+
f_seq
f_cards_g
=
None
else
:
x_id
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_id
)
x_id
=
jax
.
nn
.
swish
(
x_id
)
f_loc
=
embed
(
9
,
c
//
16
*
2
)(
x_loc
)
f_seq
=
embed
(
76
,
c
//
16
*
2
)(
x_seq
)
x_cards
=
jnp
.
concatenate
([
f_loc
,
f_seq
,
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
],
axis
=-
1
)
feats_g
=
[
x_id
,
f_loc
,
f_seq
,
x_owner
,
x_position
,
x_overley
,
x_attribute
,
x_race
,
x_level
,
x_counter
,
x_negated
,
x_atk
,
x_def
,
x_type
]
if
mask
is
not
None
:
assert
len
(
feats_g
)
==
mask
.
shape
[
-
1
]
feats
=
[
jnp
.
where
(
mask
[
...
,
i
:
i
+
1
]
==
1
,
f
,
f
[
...
,
-
1
:,
:])
for
i
,
f
in
enumerate
(
feats_g
)
]
else
:
feats
=
feats_g
x_cards
=
jnp
.
concatenate
(
feats
[
1
:],
axis
=-
1
)
x_cards
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards
)
x_cards
=
x_cards
*
x_id
x_cards
=
x_cards
*
feats
[
0
]
f_cards
=
layer_norm
()(
x_cards
)
return
f_cards
,
c_mask
if
self
.
oppo_info
:
x_cards_g
=
jnp
.
concatenate
(
feats_g
[
1
:],
axis
=-
1
)
x_cards_g
=
mlp
((
c
,),
kernel_init
=
default_fc_init2
)(
x_cards_g
)
x_cards_g
=
x_cards_g
*
feats_g
[
0
]
f_cards_g
=
layer_norm
()(
x_cards_g
)
else
:
f_cards_g
=
None
return
f_cards_g
,
f_cards
,
c_mask
class
GlobalEncoder
(
nn
.
Module
):
...
...
@@ -229,35 +247,26 @@ class Encoder(nn.Module):
id_embed
=
embed
(
n_embed
,
embed_dim
)
card_encoder
=
CardEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
version
=
self
.
version
,
oppo_info
=
self
.
oppo_info
)
ActionEncoderCls
=
ActionEncoder
if
self
.
version
==
0
else
ActionEncoderV1
action_encoder
=
ActionEncoderCls
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
x_cards_g
=
x
[
'g_cards_'
]
if
self
.
oppo_info
else
None
x_cards
=
x
[
'cards_'
]
x_global
=
x
[
'global_'
]
x_actions
=
x
[
'actions_'
]
x_h_actions
=
x
[
'h_actions_'
]
mask
=
x
[
'mask_'
]
batch_size
=
x_global
.
shape
[
0
]
valid
=
x_global
[:,
-
1
]
==
0
n_cards
=
x_cards
.
shape
[
-
2
]
if
self
.
oppo_info
:
x_cards
=
jnp
.
concatenate
([
x_cards
,
x_cards_g
],
axis
=-
2
)
x_id
=
decode_id
(
x_cards
[:,
:,
:
2
]
.
astype
(
jnp
.
int32
))
x_id
=
id_embed
(
x_id
)
if
self
.
freeze_id
:
x_id
=
jax
.
lax
.
stop_gradient
(
x_id
)
f_cards
,
c_mask
=
card_encoder
(
x_id
,
x_cards
[:,
:,
2
:])
if
self
.
oppo_info
:
f_cards_me
,
f_cards_g
=
jnp
.
split
(
f_cards
,
[
n_cards
],
axis
=-
2
)
else
:
f_cards_me
,
f_cards_g
=
f_cards
,
None
f_cards_g
,
f_cards_me
,
c_mask
=
card_encoder
(
x_id
,
x_cards
[:,
:,
2
:],
mask
)
# Cards
fs_g_card
=
[]
...
...
@@ -528,17 +537,16 @@ class GlobalCritic(nn.Module):
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
@
nn
.
compact
def
__call__
(
self
,
rstate1
,
rstate2
,
g_cards
):
f_state
=
jnp
.
concatenate
([
rstate1
[
0
],
rstate1
[
1
],
rstate2
[
0
],
rstate2
[
0
]
],
axis
=-
1
)
def
__call__
(
self
,
f_state_r1
,
f_state_r2
,
f_state
,
g_cards
):
f_state
=
jnp
.
concatenate
([
f_state_r1
,
f_state_r2
,
f_state
,
g_cards
],
axis
=-
1
)
mlp
=
partial
(
MLP
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
x
=
mlp
(
self
.
channels
,
last_lin
=
True
)(
f_state
)
c
=
self
.
channels
[
-
1
]
t
=
nn
.
Dense
(
c
*
2
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
g_cards
)
s
,
b
=
jnp
.
split
(
t
,
2
,
axis
=-
1
)
x
=
x
*
s
+
b
x
=
mlp
([
c
],
last_lin
=
False
)(
x
)
# c = self.channels[-1]
# t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
# s, b = jnp.split(t, 2, axis=-1)
# x = x * s + b
# x = mlp([c], last_lin=False)(x)
x
=
nn
.
Dense
(
1
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))(
x
)
return
x
...
...
@@ -720,9 +728,11 @@ class RNNAgent(nn.Module):
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
CriticCls
=
CrossCritic
if
self
.
batch_norm
else
Critic
cs
=
[
self
.
critic_width
]
*
self
.
critic_depth
critic
=
CriticCls
(
channels
=
cs
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
if
self
.
oppo_info
:
critic
=
GlobalCritic
(
channels
=
[
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
if
not
multi_step
:
if
isinstance
(
rstate
[
0
],
tuple
):
rstate1_t
,
rstate2_t
=
rstate
...
...
@@ -735,12 +745,9 @@ class RNNAgent(nn.Module):
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x1
,
x2
),
rstate1
,
rstate2
)
rstate2_t
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
,
x2
,
x1
),
rstate1
,
rstate2
)
value
=
critic
(
rstate1_t
,
rstate2_t
,
f_g
)
f_critic
=
jnp
.
concatenate
([
rstate1_t
[
1
],
rstate2_t
[
1
],
f_state
,
f_g
],
axis
=-
1
)
value
=
critic
(
f_critic
,
train
)
else
:
CriticCls
=
CrossCritic
if
self
.
batch_norm
else
Critic
cs
=
[
self
.
critic_width
]
*
self
.
critic_depth
critic
=
CriticCls
(
channels
=
cs
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
value
=
critic
(
f_state_r
,
train
)
if
self
.
int_head
:
...
...
ygoai/rl/jax/agent2.py
0 → 100644
View file @
72a1fd28
This diff is collapsed.
Click to expand it.
ygoai/rl/jax/utils.py
View file @
72a1fd28
...
...
@@ -10,8 +10,6 @@ import optax
import
numpy
as
np
from
ygoai.rl.env
import
RecordEpisodeStatistics
def
masked_mean
(
x
,
valid
):
x
=
jnp
.
where
(
valid
,
x
,
jnp
.
zeros_like
(
x
))
...
...
ygoai/rl/utils.py
View file @
72a1fd28
import
re
import
numpy
as
np
import
gymnasium
as
gym
import
pickle
import
optree
import
torch
from
ygoai.rl.env
import
RecordEpisodeStatistics
from
ygoai.rl.env
import
RecordEpisodeStatistics
,
EnvPreprocess
def
split_param_groups
(
model
,
regex
):
...
...
ygoenv/ygoenv/ygopro/ygopro.h
View file @
72a1fd28
...
...
@@ -1540,7 +1540,7 @@ public:
Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"obs:
g_cards_"
_
.
Bind
(
Spec
<
uint8_t
>
({
conf
[
"max_cards"
_
]
*
2
,
41
})),
"obs:
mask_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 14
})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
...
...
@@ -2337,9 +2337,18 @@ public:
return;
}
auto
[
spec_infos
,
loc_n_cards
]
=
_set_obs_cards
(
state
[
"obs:cards_"
_
],
to_play_
);
SpecInfos spec_infos;
std::vector<int> loc_n_cards;
if (spec_.config["oppo_info"_]) {
_set_obs_g_cards
(
state
[
"obs:g_cards_"
_
]);
_set_obs_g_cards(state["obs:cards_"_], to_play_);
auto [spec_infos_, loc_n_cards_] = _set_obs_mask(state["obs:mask_"_], to_play_);
spec_infos = spec_infos_;
loc_n_cards = loc_n_cards_;
} else {
auto [spec_infos_, loc_n_cards_] = _set_obs_cards(state["obs:cards_"_], to_play_);
spec_infos = spec_infos_;
loc_n_cards = loc_n_cards_;
}
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
...
...
@@ -2448,27 +2457,85 @@ private:
return {spec_infos, loc_n_cards};
}
void
_set_obs_g_cards
(
TArray
<
uint8_t
>
&
f_cards
)
{
void _set_obs_g_cards(TArray<uint8_t> &f_cards
, PlayerId to_play
) {
int offset = 0;
for (auto pi = 0; pi < 2; pi++) {
const PlayerId player = (to_play + pi) % 2;
std::vector<uint8_t> configs = {
LOCATION_DECK, LOCATION_HAND, LOCATION_MZONE,
LOCATION_SZONE, LOCATION_GRAVE, LOCATION_REMOVED,
LOCATION_EXTRA,
};
for (auto location : configs) {
std
::
vector
<
Card
>
cards
=
get_cards_in_location
(
p
i
,
location
);
std::vector<Card> cards = get_cards_in_location(p
layer
, location);
int n_cards = cards.size();
for (int i = 0; i < n_cards; ++i) {
const auto &c = cards[i];
CardId card_id = c_get_card_id(c.code_);
_set_obs_card_(f_cards, offset, c, false, card_id, false);
offset++;
if (offset == (spec_.config["max_cards"_] * 2 - 1)) {
return;
}
}
}
}
}
// Fills the per-card visibility mask for the player `to_play` and returns the
// spec -> (1-based row index, card id) table together with the number of
// cards found in each visited location.
//
// Rows are written in order: first the cards of `to_play`, then those of the
// other player, and within a player in the location order listed in
// `configs`. The observation spec gives the mask 14 columns; on the Python
// side these appear to line up with the per-card feature list
// (id, location, sequence, owner, position, overlay, attribute, race, level,
// counter, negated, atk, def, type) -- NOTE(review): confirm against the
// agent's feature ordering.
//
// NOTE(review): unlike _set_obs_g_cards, there is no early exit when `offset`
// reaches the row capacity of `mask`; confirm the caller guarantees the card
// count stays within bounds.
std::tuple<SpecInfos, std::vector<int>> _set_obs_mask(TArray<uint8_t> &mask, PlayerId to_play) {
  SpecInfos spec_infos;
  std::vector<int> loc_n_cards;
  int offset = 0;
  for (auto pi = 0; pi < 2; pi++) {
    const PlayerId player = (to_play + pi) % 2;
    const bool opponent = pi == 1;
    // {location, hidden from the opponent's point of view}
    std::vector<std::pair<uint8_t, bool>> configs = {
        {LOCATION_DECK, true},   {LOCATION_HAND, true},
        {LOCATION_MZONE, false}, {LOCATION_SZONE, false},
        {LOCATION_GRAVE, false}, {LOCATION_REMOVED, false},
        {LOCATION_EXTRA, true},
    };
    for (auto &[location, hidden_for_opponent] : configs) {
      // check this
      // While any card is revealed, opponent locations are walked card by
      // card so that individual revealed cards can be unmasked below.
      if (opponent && (revealed_.size() != 0)) {
        hidden_for_opponent = false;
      }
      if (opponent && hidden_for_opponent) {
        // Fully hidden location: only the card count is queried, and only
        // columns 1 and 3 are marked for each row.
        auto n_cards = YGO_QueryFieldCount(pduel_, player, location);
        loc_n_cards.push_back(n_cards);
        for (auto i = 0; i < n_cards; i++) {
          mask(offset, 1) = 1;
          mask(offset, 3) = 1;
          offset++;
        }
      } else {
        std::vector<Card> cards = get_cards_in_location(player, location);
        int n_cards = cards.size();
        loc_n_cards.push_back(n_cards);
        for (int i = 0; i < n_cards; ++i) {
          const auto &c = cards[i];
          auto spec = c.get_spec(opponent);
          bool hide = false;
          if (opponent) {
            // Face-down opponent cards are hidden unless explicitly revealed.
            hide = c.position_ & POS_FACEDOWN;
            if (revealed_.find(spec) != revealed_.end()) {
              hide = false;
            }
          }
          CardId card_id = 0;
          if (!hide) {
            card_id = c_get_card_id(c.code_);
          }
          // Forward card_id: _set_obs_mask_ only marks column 0 when it is
          // non-zero, and its default argument is 0, so omitting it left
          // column 0 unset for every card.
          _set_obs_mask_(mask, offset, c, hide, card_id);
          offset++;
          // The stored index is taken after the increment, i.e. it is 1-based.
          spec_infos[spec] = {static_cast<uint16_t>(offset), card_id};
        }
      }
    }
  }
  return {spec_infos, loc_n_cards};
}
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide, CardId card_id = 0, bool global = false) {
...
...
@@ -2531,6 +2598,54 @@ private:
}
}
// Marks, for the single card `c`, which columns of row `offset` in `mask`
// are set to 1. Columns that are not written here are left untouched.
//
//   mask    - output array, indexed as mask(row, column)
//   offset  - row to write
//   c       - card whose location_ / position_ drive the decisions
//   hide    - caller's request to treat the card as hidden; ignored (forced
//             to false) for overlay cards
//   card_id - column 0 is only marked when this is non-zero and the card is
//             not hidden
//   global  - not used by this function
void _set_obs_mask_(TArray<uint8_t> &mask, int offset, const Card &c,
                    bool hide, CardId card_id = 0, bool global = false) {
  // check offset exceeds max_cards
  uint8_t loc = c.location_;
  const bool is_overlay = (loc & LOCATION_OVERLAY) != 0;
  if (is_overlay) {
    // Strip the overlay bit to recover the base location; overlay cards are
    // never hidden.
    loc &= 0x7f;
    hide = false;
  }
  const bool visible = !hide;

  if (visible && card_id != 0) {
    mask(offset, 0) = 1;
  }

  mask(offset, 1) = 1;

  const bool on_field_or_grave =
      loc == LOCATION_MZONE || loc == LOCATION_SZONE || loc == LOCATION_GRAVE;
  if (on_field_or_grave) {
    mask(offset, 2) = 1;
  }

  mask(offset, 3) = 1;

  if (is_overlay) {
    mask(offset, 4) = 1;
    mask(offset, 5) = 1;
  } else {
    const bool private_zone =
        loc == LOCATION_DECK || loc == LOCATION_HAND || loc == LOCATION_EXTRA;
    // In deck/hand/extra, column 4 is marked only for hidden or face-down
    // cards; everywhere else it is always marked.
    if (!private_zone || hide || (c.position_ & POS_FACEDOWN)) {
      mask(offset, 4) = 1;
    }
  }

  if (visible) {
    // Columns 6..13 are marked together for any card that is not hidden.
    for (int col = 6; col <= 13; ++col) {
      mask(offset, col) = 1;
    }
  }
}
void _set_obs_global(TArray<uint8_t> &feat, PlayerId player, const std::vector<int> &loc_n_cards) {
uint8_t me = player;
uint8_t op = 1 - player;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment