Commit 72a1fd28 authored by sbl1996@126.com

Add oppo_info

parent 5cd9807d
@@ -23,9 +23,10 @@ from rich.pretty import pprint
 from tensorboardX import SummaryWriter

 from ygoai.utils import init_ygopro, load_embeddings
+from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent import RNNAgent, ModelArgs
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample, TrainState
+from ygoai.rl.jax.utils import masked_normalize, categorical_sample, TrainState
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
 from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
@@ -356,6 +357,7 @@ def rollout(
         args.local_env_threads,
         thread_affinity_offset=device_thread_id * args.local_env_threads,
     )
+    envs = EnvPreprocess(envs, skip_mask=not args.m1.oppo_info)
     envs = RecordEpisodeStatistics(envs)

     eval_envs = make_env(
@@ -363,6 +365,7 @@ def rollout(
         local_seed + 100000,
         args.local_eval_episodes,
         args.local_eval_episodes // 4, mode=eval_mode, eval=True)
+    eval_envs = EnvPreprocess(eval_envs, skip_mask=True)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)
@@ -440,9 +443,6 @@ def rollout(
     init_rstates = []

-    # @jax.jit
-    # def prepare_data(storage: List[Transition]) -> Transition:
-    #     return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:
         return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
@@ -566,7 +566,7 @@ def rollout(
         for x in partitioned_storage:
             if isinstance(x, dict):
                 x = {
-                    k: jax.device_put_sharded(v, devices=learner_devices)
+                    k: jax.device_put_sharded(v, devices=learner_devices) if v is not None else None
                     for k, v in x.items()
                 }
             elif x is not None:
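A minimal standalone sketch of why the new None guard above is needed (not from the repo; the observation contents below are made up): jax.device_put_sharded only accepts a list of array shards, so the None that EnvPreprocess writes into 'mask_' has to be passed through unchanged.

import jax
import jax.numpy as jnp

devices = jax.local_devices()
# one shard per learner device; 'mask_' was nulled out by EnvPreprocess(skip_mask=True)
obs = {
    "cards_": [jnp.zeros((4, 3))] * len(devices),
    "mask_": None,
}
sharded = {
    k: jax.device_put_sharded(v, devices=devices) if v is not None else None
    for k, v in obs.items()
}
assert sharded["mask_"] is None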
......
This diff is collapsed.
@@ -60,15 +60,40 @@ class RecordEpisodeStatistics(gym.Wrapper):

 class CompatEnv(gym.Wrapper):

     def reset(self, **kwargs):
-        observations, infos = super().reset(**kwargs)
+        observations, infos = self.env.reset(**kwargs)
         return observations, infos

     def step(self, action):
-        observations, rewards, terminated, truncated, infos = self.env.step(action)
+        observations, rewards, terminated, truncated, infos = super().step(action)
         dones = np.logical_or(terminated, truncated)
         return (
             observations,
             rewards,
             dones,
             infos,
         )
+
+
+class EnvPreprocess(gym.Wrapper):
+    def __init__(self, env, skip_mask):
+        super().__init__(env)
+        self.skip_mask = skip_mask
+
+    def reset(self, **kwargs):
+        observations, infos = self.env.reset(**kwargs)
+        if self.skip_mask:
+            observations['mask_'] = None
+        return observations, infos
+
+    def step(self, action):
+        observations, rewards, terminated, truncated, infos = super().step(action)
+        if self.skip_mask:
+            observations['mask_'] = None
+        return (
+            observations,
+            rewards,
+            terminated,
+            truncated,
+            infos,
+        )
\ No newline at end of file
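A short usage sketch of the new EnvPreprocess wrapper (the dummy environment and its shapes are invented for illustration; only EnvPreprocess itself comes from this commit). With skip_mask=True the wrapper nulls out the 'mask_' observation so downstream code can ignore it.

import gymnasium as gym
import numpy as np
from ygoai.rl.env import EnvPreprocess

class DummyYGOEnv(gym.Env):
    # stand-in for the real ygopro env: dict observation with a 'mask_' entry
    observation_space = gym.spaces.Dict({
        "cards_": gym.spaces.Box(0, 255, (160, 41), dtype=np.uint8),
        "mask_": gym.spaces.Box(0, 1, (160, 14), dtype=np.uint8),
    })
    action_space = gym.spaces.Discrete(2)

    def reset(self, **kwargs):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}

env = EnvPreprocess(DummyYGOEnv(), skip_mask=True)
obs, _ = env.reset()
assert obs["mask_"] is None  # the agent never sees the per-feature visibility mask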
@@ -85,7 +85,8 @@ class CardEncoder(nn.Module):
     version: int = 0

     @nn.compact
-    def __call__(self, x_id, x):
+    def __call__(self, x_id, x, mask):
+        assert self.version > 0
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
@@ -136,18 +137,35 @@ class CardEncoder(nn.Module):
             x_f = layer_norm()(x_f)
             f_cards = jnp.concatenate([x_id, x_f], axis=-1)
             f_cards = f_cards + f_loc + f_seq
+            f_cards_g = None
         else:
             x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
             x_id = jax.nn.swish(x_id)
             f_loc = embed(9, c // 16 * 2)(x_loc)
             f_seq = embed(76, c // 16 * 2)(x_seq)
-            x_cards = jnp.concatenate([
-                f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
-                x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type], axis=-1)
+            feats_g = [
+                x_id, f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
+                x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
+            if mask is not None:
+                assert len(feats_g) == mask.shape[-1]
+                feats = [
+                    jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
+                    for i, f in enumerate(feats_g)
+                ]
+            else:
+                feats = feats_g
+            x_cards = jnp.concatenate(feats[1:], axis=-1)
             x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
-            x_cards = x_cards * x_id
+            x_cards = x_cards * feats[0]
             f_cards = layer_norm()(x_cards)
-        return f_cards, c_mask
+            if self.oppo_info:
+                x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
+                x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
+                x_cards_g = x_cards_g * feats_g[0]
+                f_cards_g = layer_norm()(x_cards_g)
+            else:
+                f_cards_g = None
+        return f_cards_g, f_cards, c_mask


 class GlobalEncoder(nn.Module):
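The per-feature masking added to CardEncoder replaces every feature group the acting player should not see with the feature of the last card slot (assumed here to be a padding/unknown entry). A tiny standalone illustration of the same jnp.where pattern (array contents are arbitrary, not from the repo):

import jax.numpy as jnp

n_cards, channels = 4, 3
feat = jnp.arange(n_cards * channels, dtype=jnp.float32).reshape(n_cards, channels)
# one mask column per feature group: 1 = visible, 0 = hidden for the acting player
mask_col = jnp.array([[1], [0], [1], [0]], dtype=jnp.uint8)

# hidden rows fall back to the last card slot's feature, mirroring
# jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :]) in CardEncoder
masked = jnp.where(mask_col == 1, feat, feat[-1:, :])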
@@ -229,35 +247,26 @@ class Encoder(nn.Module):
         id_embed = embed(n_embed, embed_dim)
         card_encoder = CardEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)
+            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
+            version=self.version, oppo_info=self.oppo_info)
         ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
         action_encoder = ActionEncoderCls(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)

-        x_cards_g = x['g_cards_'] if self.oppo_info else None
         x_cards = x['cards_']
         x_global = x['global_']
         x_actions = x['actions_']
         x_h_actions = x['h_actions_']
+        mask = x['mask_']

         batch_size = x_global.shape[0]
         valid = x_global[:, -1] == 0

-        n_cards = x_cards.shape[-2]
-        if self.oppo_info:
-            x_cards = jnp.concatenate([x_cards, x_cards_g], axis=-2)
         x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
         x_id = id_embed(x_id)
         if self.freeze_id:
             x_id = jax.lax.stop_gradient(x_id)
-        f_cards, c_mask = card_encoder(x_id, x_cards[:, :, 2:])
-        if self.oppo_info:
-            f_cards_me, f_cards_g = jnp.split(f_cards, [n_cards], axis=-2)
-        else:
-            f_cards_me, f_cards_g = f_cards, None
+        f_cards_g, f_cards_me, c_mask = card_encoder(x_id, x_cards[:, :, 2:], mask)

         # Cards
         fs_g_card = []
@@ -526,19 +535,18 @@ class GlobalCritic(nn.Module):
     channels: Sequence[int] = (128, 128)
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32

     @nn.compact
-    def __call__(self, rstate1, rstate2, g_cards):
-        f_state = jnp.concatenate([rstate1[0], rstate1[1], rstate2[0], rstate2[0]], axis=-1)
+    def __call__(self, f_state_r1, f_state_r2, f_state, g_cards):
+        f_state = jnp.concatenate([f_state_r1, f_state_r2, f_state, g_cards], axis=-1)
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
         x = mlp(self.channels, last_lin=True)(f_state)
-        c = self.channels[-1]
-        t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
-        s, b = jnp.split(t, 2, axis=-1)
-        x = x * s + b
-        # x = mlp([c], last_lin=False)(x)
+        # c = self.channels[-1]
+        # t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
+        # s, b = jnp.split(t, 2, axis=-1)
+        # x = x * s + b
         x = mlp([c], last_lin=False)(x)
         x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
         return x
@@ -720,9 +728,11 @@ class RNNAgent(nn.Module):
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
         logits = actor(f_state_r, f_actions, mask)

+        CriticCls = CrossCritic if self.batch_norm else Critic
+        cs = [self.critic_width] * self.critic_depth
+        critic = CriticCls(
+            channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
         if self.oppo_info:
-            critic = GlobalCritic(
-                channels=[c, c], dtype=self.dtype, param_dtype=self.param_dtype)
             if not multi_step:
                 if isinstance(rstate[0], tuple):
                     rstate1_t, rstate2_t = rstate
@@ -735,12 +745,9 @@ class RNNAgent(nn.Module):
                         lambda x1, x2: jnp.where(main, x1, x2), rstate1, rstate2)
                     rstate2_t = jax.tree.map(
                         lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
-            value = critic(rstate1_t, rstate2_t, f_g)
+            f_critic = jnp.concatenate([rstate1_t[1], rstate2_t[1], f_state, f_g], axis=-1)
+            value = critic(f_critic, train)
         else:
-            CriticCls = CrossCritic if self.batch_norm else Critic
-            cs = [self.critic_width] * self.critic_depth
-            critic = CriticCls(
-                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
             value = critic(f_state_r, train)

         if self.int_head:
......
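A rough sketch of what the oppo_info value head now consumes (shapes are illustrative and the equal widths are an assumption; the real critic is the repo's Critic/CrossCritic module): the critic input concatenates both players' recurrent states with the step features and the unmasked global card summary f_g, while the actor keeps using only the masked observation.

import jax.numpy as jnp

batch, c = 2, 8
h1 = jnp.ones((batch, c))       # stands in for rstate1_t[1]: recurrent state from the main player's view
h2 = jnp.zeros((batch, c))      # stands in for rstate2_t[1]: recurrent state from the other view
f_state = jnp.ones((batch, c))  # per-step state features
f_g = jnp.ones((batch, c))      # pooled features over all cards, including hidden ones

f_critic = jnp.concatenate([h1, h2, f_state, f_g], axis=-1)
assert f_critic.shape == (batch, 4 * c)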
This diff is collapsed.
@@ -10,8 +10,6 @@ import optax
 import numpy as np

-from ygoai.rl.env import RecordEpisodeStatistics
-

 def masked_mean(x, valid):
     x = jnp.where(valid, x, jnp.zeros_like(x))
......
 import re
+import numpy as np
+import gymnasium as gym
+import pickle
 import optree
 import torch

-from ygoai.rl.env import RecordEpisodeStatistics
+from ygoai.rl.env import RecordEpisodeStatistics, EnvPreprocess


 def split_param_groups(model, regex):
......
@@ -1540,7 +1540,7 @@ public:
              Spec<uint8_t>({conf["max_options"_], n_action_feats})),
          "obs:h_actions_"_.Bind(
              Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
-         "obs:g_cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
+         "obs:mask_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 14})),
          "info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
          "info:to_play"_.Bind(Spec<int>({}, {0, 1})),
          "info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
@@ -2337,9 +2337,18 @@ public:
       return;
     }

-    auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
+    SpecInfos spec_infos;
+    std::vector<int> loc_n_cards;
     if (spec_.config["oppo_info"_]) {
-      _set_obs_g_cards(state["obs:g_cards_"_]);
+      _set_obs_g_cards(state["obs:cards_"_], to_play_);
+      auto [spec_infos_, loc_n_cards_] = _set_obs_mask(state["obs:mask_"_], to_play_);
+      spec_infos = spec_infos_;
+      loc_n_cards = loc_n_cards_;
+    } else {
+      auto [spec_infos_, loc_n_cards_] = _set_obs_cards(state["obs:cards_"_], to_play_);
+      spec_infos = spec_infos_;
+      loc_n_cards = loc_n_cards_;
     }

     _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
@@ -2448,27 +2457,85 @@ private:
     return {spec_infos, loc_n_cards};
   }

-  void _set_obs_g_cards(TArray<uint8_t> &f_cards) {
+  void _set_obs_g_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
     int offset = 0;
     for (auto pi = 0; pi < 2; pi++) {
+      const PlayerId player = (to_play + pi) % 2;
       std::vector<uint8_t> configs = {
          LOCATION_DECK,  LOCATION_HAND,  LOCATION_MZONE,
          LOCATION_SZONE, LOCATION_GRAVE, LOCATION_REMOVED,
          LOCATION_EXTRA,
       };
       for (auto location : configs) {
-        std::vector<Card> cards = get_cards_in_location(pi, location);
+        std::vector<Card> cards = get_cards_in_location(player, location);
         int n_cards = cards.size();
         for (int i = 0; i < n_cards; ++i) {
           const auto &c = cards[i];
           CardId card_id = c_get_card_id(c.code_);
           _set_obs_card_(f_cards, offset, c, false, card_id, false);
           offset++;
+          if (offset == (spec_.config["max_cards"_] * 2 - 1)) {
+            return;
+          }
         }
       }
     }
   }

+  std::tuple<SpecInfos, std::vector<int>> _set_obs_mask(TArray<uint8_t> &mask, PlayerId to_play) {
+    SpecInfos spec_infos;
+    std::vector<int> loc_n_cards;
+    int offset = 0;
+    for (auto pi = 0; pi < 2; pi++) {
+      const PlayerId player = (to_play + pi) % 2;
+      const bool opponent = pi == 1;
+      std::vector<std::pair<uint8_t, bool>> configs = {
+          {LOCATION_DECK, true},   {LOCATION_HAND, true},
+          {LOCATION_MZONE, false}, {LOCATION_SZONE, false},
+          {LOCATION_GRAVE, false}, {LOCATION_REMOVED, false},
+          {LOCATION_EXTRA, true},
+      };
+      for (auto &[location, hidden_for_opponent] : configs) {
+        // check this
+        if (opponent && (revealed_.size() != 0)) {
+          hidden_for_opponent = false;
+        }
+        if (opponent && hidden_for_opponent) {
+          auto n_cards = YGO_QueryFieldCount(pduel_, player, location);
+          loc_n_cards.push_back(n_cards);
+          for (auto i = 0; i < n_cards; i++) {
+            mask(offset, 1) = 1;
+            mask(offset, 3) = 1;
+            offset++;
+          }
+        } else {
+          std::vector<Card> cards = get_cards_in_location(player, location);
+          int n_cards = cards.size();
+          loc_n_cards.push_back(n_cards);
+          for (int i = 0; i < n_cards; ++i) {
+            const auto &c = cards[i];
+            auto spec = c.get_spec(opponent);
+            bool hide = false;
+            if (opponent) {
+              hide = c.position_ & POS_FACEDOWN;
+              if (revealed_.find(spec) != revealed_.end()) {
+                hide = false;
+              }
+            }
+            CardId card_id = 0;
+            if (!hide) {
+              card_id = c_get_card_id(c.code_);
+            }
+            _set_obs_mask_(mask, offset, c, hide);
+            offset++;
+            spec_infos[spec] = {static_cast<uint16_t>(offset), card_id};
+          }
+        }
+      }
+    }
+    return {spec_infos, loc_n_cards};
+  }

   void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
                       bool hide, CardId card_id = 0, bool global = false) {
@@ -2531,6 +2598,54 @@ private:
     }
   }

+  void _set_obs_mask_(TArray<uint8_t> &mask, int offset, const Card &c,
+                      bool hide, CardId card_id = 0, bool global = false) {
+    // check offset exceeds max_cards
+    uint8_t location = c.location_;
+    bool overlay = location & LOCATION_OVERLAY;
+    if (overlay) {
+      location = location & 0x7f;
+    }
+    if (overlay) {
+      hide = false;
+    }
+    if (!hide) {
+      if (card_id != 0) {
+        mask(offset, 0) = 1;
+      }
+    }
+    mask(offset, 1) = 1;
+    if (location == LOCATION_MZONE || location == LOCATION_SZONE ||
+        location == LOCATION_GRAVE) {
+      mask(offset, 2) = 1;
+    }
+    mask(offset, 3) = 1;
+    if (overlay) {
+      mask(offset, 4) = 1;
+      mask(offset, 5) = 1;
+    } else {
+      if (location == LOCATION_DECK || location == LOCATION_HAND || location == LOCATION_EXTRA) {
+        if (hide || (c.position_ & POS_FACEDOWN)) {
+          mask(offset, 4) = 1;
+        }
+      } else {
+        mask(offset, 4) = 1;
+      }
+    }
+    if (!hide) {
+      mask(offset, 6) = 1;
+      mask(offset, 7) = 1;
+      mask(offset, 8) = 1;
+      mask(offset, 9) = 1;
+      mask(offset, 10) = 1;
+      mask(offset, 11) = 1;
+      mask(offset, 12) = 1;
+      mask(offset, 13) = 1;
+    }
+  }

   void _set_obs_global(TArray<uint8_t> &feat, PlayerId player, const std::vector<int> &loc_n_cards) {
     uint8_t me = player;
     uint8_t op = 1 - player;
......
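An inferred reference, not stated anywhere in the commit: the 14 columns of obs:mask_ appear to line up, by position, with the 14 feature groups in CardEncoder's feats_g, which is what the assert len(feats_g) == mask.shape[-1] relies on. Summarized as a Python dict for convenience:

# column index in obs:mask_ -> feature group masked in CardEncoder (inferred from ordering)
MASK_COLUMNS = {
    0: "x_id",        # set only when the card id is known to the acting player
    1: "f_loc",       # location, always visible
    2: "f_seq",       # sequence, only for MZONE / SZONE / GRAVE
    3: "x_owner",     # always visible
    4: "x_position",
    5: "x_overley",   # overlay flag (spelling follows the code)
    6: "x_attribute",
    7: "x_race",
    8: "x_level",
    9: "x_counter",
    10: "x_negated",
    11: "x_atk",
    12: "x_def",
    13: "x_type",
}
assert len(MASK_COLUMNS) == 14  # matches Spec<uint8_t>({max_cards * 2, 14})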