Commit 72a1fd28 authored by sbl1996@126.com

Add oppo_info

parent 5cd9807d
......@@ -23,9 +23,10 @@ from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample, TrainState
from ygoai.rl.jax.utils import masked_normalize, categorical_sample, TrainState
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
......@@ -356,6 +357,7 @@ def rollout(
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = EnvPreprocess(envs, skip_mask=not args.m1.oppo_info)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
......@@ -363,6 +365,7 @@ def rollout(
local_seed + 100000,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = EnvPreprocess(eval_envs, skip_mask=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
......@@ -440,9 +443,6 @@ def rollout(
init_rstates = []
# @jax.jit
# def prepare_data(storage: List[Transition]) -> Transition:
# return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
......@@ -566,7 +566,7 @@ def rollout(
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
k: jax.device_put_sharded(v, devices=learner_devices) if v is not None else None
for k, v in x.items()
}
elif x is not None:
......
This diff is collapsed.
......@@ -60,15 +60,40 @@ class RecordEpisodeStatistics(gym.Wrapper):
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
observations, infos = self.env.reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
class EnvPreprocess(gym.Wrapper):
    """Wrapper that can blank out the ``'mask_'`` entry of every observation.

    When ``skip_mask`` is truthy, ``observations['mask_']`` is overwritten
    with ``None`` after each ``reset`` and ``step``; otherwise observations
    are passed through untouched. Rewards, termination flags and infos are
    never modified.
    """

    def __init__(self, env, skip_mask):
        super().__init__(env)
        # Whether the 'mask_' observation should be replaced by None.
        self.skip_mask = skip_mask

    def _apply_skip_mask(self, obs):
        """Set ``obs['mask_']`` to ``None`` in place if masking is skipped."""
        if self.skip_mask:
            obs['mask_'] = None
        return obs

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(observations, infos)``."""
        # Forward kwargs straight to the wrapped env, exactly as given.
        obs, info = self.env.reset(**kwargs)
        return self._apply_skip_mask(obs), info

    def step(self, action):
        """Step the env; return the usual 5-tuple with the mask handled."""
        obs, reward, terminated, truncated, info = super().step(action)
        return self._apply_skip_mask(obs), reward, terminated, truncated, info
\ No newline at end of file
......@@ -85,7 +85,8 @@ class CardEncoder(nn.Module):
version: int = 0
@nn.compact
def __call__(self, x_id, x):
def __call__(self, x_id, x, mask):
assert self.version > 0
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
......@@ -136,18 +137,35 @@ class CardEncoder(nn.Module):
x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
f_cards_g = None
else:
x_id = mlp((c,), kernel_init=default_fc_init2)(x_id)
x_id = jax.nn.swish(x_id)
f_loc = embed(9, c // 16 * 2)(x_loc)
f_seq = embed(76, c // 16 * 2)(x_seq)
x_cards = jnp.concatenate([
f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type], axis=-1)
feats_g = [
x_id, f_loc, f_seq, x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
if mask is not None:
assert len(feats_g) == mask.shape[-1]
feats = [
jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
for i, f in enumerate(feats_g)
]
else:
feats = feats_g
x_cards = jnp.concatenate(feats[1:], axis=-1)
x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
x_cards = x_cards * x_id
x_cards = x_cards * feats[0]
f_cards = layer_norm()(x_cards)
return f_cards, c_mask
if self.oppo_info:
x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
x_cards_g = x_cards_g * feats_g[0]
f_cards_g = layer_norm()(x_cards_g)
else:
f_cards_g = None
return f_cards_g, f_cards, c_mask
class GlobalEncoder(nn.Module):
......@@ -229,35 +247,26 @@ class Encoder(nn.Module):
id_embed = embed(n_embed, embed_dim)
card_encoder = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
version=self.version, oppo_info=self.oppo_info)
ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
action_encoder = ActionEncoderCls(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards_g = x['g_cards_'] if self.oppo_info else None
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
mask = x['mask_']
batch_size = x_global.shape[0]
valid = x_global[:, -1] == 0
n_cards = x_cards.shape[-2]
if self.oppo_info:
x_cards = jnp.concatenate([x_cards, x_cards_g], axis=-2)
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id)
if self.freeze_id:
x_id = jax.lax.stop_gradient(x_id)
f_cards, c_mask = card_encoder(x_id, x_cards[:, :, 2:])
if self.oppo_info:
f_cards_me, f_cards_g = jnp.split(f_cards, [n_cards], axis=-2)
else:
f_cards_me, f_cards_g = f_cards, None
f_cards_g, f_cards_me, c_mask = card_encoder(x_id, x_cards[:, :, 2:], mask)
# Cards
fs_g_card = []
......@@ -526,19 +535,18 @@ class GlobalCritic(nn.Module):
channels: Sequence[int] = (128, 128)
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, rstate1, rstate2, g_cards):
f_state = jnp.concatenate([rstate1[0], rstate1[1], rstate2[0], rstate2[0]], axis=-1)
def __call__(self, f_state_r1, f_state_r2, f_state, g_cards):
f_state = jnp.concatenate([f_state_r1, f_state_r2, f_state, g_cards], axis=-1)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=True)(f_state)
c = self.channels[-1]
t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
s, b = jnp.split(t, 2, axis=-1)
x = x * s + b
x = mlp([c], last_lin=False)(x)
# c = self.channels[-1]
# t = nn.Dense(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(g_cards)
# s, b = jnp.split(t, 2, axis=-1)
# x = x * s + b
# x = mlp([c], last_lin=False)(x)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
return x
......@@ -720,9 +728,11 @@ class RNNAgent(nn.Module):
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
logits = actor(f_state_r, f_actions, mask)
CriticCls = CrossCritic if self.batch_norm else Critic
cs = [self.critic_width] * self.critic_depth
critic = CriticCls(
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
if self.oppo_info:
critic = GlobalCritic(
channels=[c, c], dtype=self.dtype, param_dtype=self.param_dtype)
if not multi_step:
if isinstance(rstate[0], tuple):
rstate1_t, rstate2_t = rstate
......@@ -735,12 +745,9 @@ class RNNAgent(nn.Module):
lambda x1, x2: jnp.where(main, x1, x2), rstate1, rstate2)
rstate2_t = jax.tree.map(
lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
value = critic(rstate1_t, rstate2_t, f_g)
f_critic = jnp.concatenate([rstate1_t[1], rstate2_t[1], f_state, f_g], axis=-1)
value = critic(f_critic, train)
else:
CriticCls = CrossCritic if self.batch_norm else Critic
cs = [self.critic_width] * self.critic_depth
critic = CriticCls(
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value = critic(f_state_r, train)
if self.int_head:
......
This diff is collapsed.
......@@ -10,8 +10,6 @@ import optax
import numpy as np
from ygoai.rl.env import RecordEpisodeStatistics
def masked_mean(x, valid):
x = jnp.where(valid, x, jnp.zeros_like(x))
......
import re
import numpy as np
import gymnasium as gym
import pickle
import optree
import torch
from ygoai.rl.env import RecordEpisodeStatistics
from ygoai.rl.env import RecordEpisodeStatistics, EnvPreprocess
def split_param_groups(model, regex):
......
......@@ -1540,7 +1540,7 @@ public:
Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"obs:g_cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 41})),
"obs:mask_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 14})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
......@@ -2337,9 +2337,18 @@ public:
return;
}
auto [spec_infos, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);
SpecInfos spec_infos;
std::vector<int> loc_n_cards;
if (spec_.config["oppo_info"_]) {
_set_obs_g_cards(state["obs:g_cards_"_]);
_set_obs_g_cards(state["obs:cards_"_], to_play_);
auto [spec_infos_, loc_n_cards_] = _set_obs_mask(state["obs:mask_"_], to_play_);
spec_infos = spec_infos_;
loc_n_cards = loc_n_cards_;
} else {
auto [spec_infos_, loc_n_cards_] = _set_obs_cards(state["obs:cards_"_], to_play_);
spec_infos = spec_infos_;
loc_n_cards = loc_n_cards_;
}
_set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);
......@@ -2448,27 +2457,85 @@ private:
return {spec_infos, loc_n_cards};
}
void _set_obs_g_cards(TArray<uint8_t> &f_cards) {
void _set_obs_g_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
int offset = 0;
for (auto pi = 0; pi < 2; pi++) {
const PlayerId player = (to_play + pi) % 2;
std::vector<uint8_t> configs = {
LOCATION_DECK, LOCATION_HAND, LOCATION_MZONE,
LOCATION_SZONE, LOCATION_GRAVE, LOCATION_REMOVED,
LOCATION_EXTRA,
};
for (auto location : configs) {
std::vector<Card> cards = get_cards_in_location(pi, location);
std::vector<Card> cards = get_cards_in_location(player, location);
int n_cards = cards.size();
for (int i = 0; i < n_cards; ++i) {
const auto &c = cards[i];
CardId card_id = c_get_card_id(c.code_);
_set_obs_card_(f_cards, offset, c, false, card_id, false);
offset++;
if (offset == (spec_.config["max_cards"_] * 2 - 1)) {
return;
}
}
}
}
}
// Fills the per-card feature mask for both players and returns the
// bookkeeping the caller destructures as {spec_infos, loc_n_cards}.
//
// Players are visited in the order (to_play, the other player); rows of
// `mask` are written consecutively starting at row 0. For each of the seven
// locations one entry is appended to `loc_n_cards` (the card count there).
//
// Two paths per location:
//  * Opponent location flagged as hidden (deck / hand / extra) while
//    `revealed_` is empty: only the count is queried and, for every card,
//    columns 1 and 3 are set. No spec_infos entry is recorded.
//  * Otherwise each card is fetched, `_set_obs_mask_` fills its row, and
//    spec_infos[spec] = {row index after the increment, card id} is stored.
//    The card id is 0 for opponent cards that are face-down and not revealed.
//
// NOTE(review): `offset` is not checked against the number of rows in `mask`.
std::tuple<SpecInfos, std::vector<int>> _set_obs_mask(TArray<uint8_t> &mask, PlayerId to_play) {
  SpecInfos spec_infos;
  std::vector<int> loc_n_cards;

  // {location, concealed from the opponent by default}
  const std::vector<std::pair<uint8_t, bool>> configs = {
      {LOCATION_DECK, true},   {LOCATION_HAND, true},
      {LOCATION_MZONE, false}, {LOCATION_SZONE, false},
      {LOCATION_GRAVE, false}, {LOCATION_REMOVED, false},
      {LOCATION_EXTRA, true},
  };

  int offset = 0;
  for (auto pi = 0; pi < 2; pi++) {
    const PlayerId player = (to_play + pi) % 2;
    const bool opponent = pi == 1;

    for (const auto &[location, hidden_for_opponent] : configs) {
      // TODO(review): carried over from the original "check this" note --
      // as soon as *any* card is revealed, every opponent location is
      // enumerated card by card instead of being counted only.
      const bool count_only =
          opponent && hidden_for_opponent && revealed_.empty();

      if (count_only) {
        auto n_cards = YGO_QueryFieldCount(pduel_, player, location);
        loc_n_cards.push_back(n_cards);
        for (auto i = 0; i < n_cards; i++) {
          mask(offset, 1) = 1;
          mask(offset, 3) = 1;
          offset++;
        }
        continue;
      }

      std::vector<Card> cards = get_cards_in_location(player, location);
      int n_cards = cards.size();
      loc_n_cards.push_back(n_cards);
      for (int i = 0; i < n_cards; ++i) {
        const auto &c = cards[i];
        auto spec = c.get_spec(opponent);

        // Opponent cards are hidden when face-down unless explicitly revealed.
        bool hide = false;
        if (opponent) {
          hide = c.position_ & POS_FACEDOWN;
          if (revealed_.find(spec) != revealed_.end()) {
            hide = false;
          }
        }

        CardId card_id = 0;
        if (!hide) {
          card_id = c_get_card_id(c.code_);
        }

        // Forward card_id: _set_obs_mask_ only marks column 0 when it
        // receives a non-zero id, and its parameter defaults to 0.
        _set_obs_mask_(mask, offset, c, hide, card_id);
        offset++;
        spec_infos[spec] = {static_cast<uint16_t>(offset), card_id};
      }
    }
  }
  return {spec_infos, loc_n_cards};
}
void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
bool hide, CardId card_id = 0, bool global = false) {
......@@ -2531,6 +2598,54 @@ private:
}
}
// Writes one row (`offset`) of the card feature mask for card `c`.
//
// A column is set to 1 under the conditions below; columns are never cleared
// here. `hide` is forced to false for overlay cards. `global` is accepted but
// not used by this function.
//
//   col 0      : !hide and card_id != 0
//   col 1, 3   : always
//   col 2      : location is MZONE, SZONE or GRAVE
//   col 4      : overlay; or location is not DECK/HAND/EXTRA; or the card is
//                hidden / face-down while in DECK/HAND/EXTRA
//   col 5      : overlay only
//   col 6..13  : !hide
//
// NOTE(review): the column order presumably mirrors the 14 per-card feature
// groups consumed on the Python side (id, location, sequence, owner,
// position, overlay, attribute, race, level, counter, negated, atk, def,
// type) -- confirm against the encoder.
// NOTE(review): `offset` is not checked against the number of mask rows.
void _set_obs_mask_(TArray<uint8_t> &mask, int offset, const Card &c,
bool hide, CardId card_id = 0, bool global = false) {
  uint8_t location = c.location_;
  const bool overlay = location & LOCATION_OVERLAY;
  if (overlay) {
    // Clear the high bit (0x7f mask) to recover the base location; overlay
    // cards are never treated as hidden.
    location &= 0x7f;
    hide = false;
  }

  if (!hide && card_id != 0) {
    mask(offset, 0) = 1;
  }

  mask(offset, 1) = 1;

  const bool in_field_or_grave = location == LOCATION_MZONE ||
                                 location == LOCATION_SZONE ||
                                 location == LOCATION_GRAVE;
  if (in_field_or_grave) {
    mask(offset, 2) = 1;
  }

  mask(offset, 3) = 1;

  const bool in_private_zone = location == LOCATION_DECK ||
                               location == LOCATION_HAND ||
                               location == LOCATION_EXTRA;
  if (overlay) {
    mask(offset, 4) = 1;
    mask(offset, 5) = 1;
  } else if (!in_private_zone || hide || (c.position_ & POS_FACEDOWN)) {
    mask(offset, 4) = 1;
  }

  if (!hide) {
    for (int col = 6; col <= 13; ++col) {
      mask(offset, col) = 1;
    }
  }
}
void _set_obs_global(TArray<uint8_t> &feat, PlayerId player, const std::vector<int> &loc_n_cards) {
uint8_t me = player;
uint8_t op = 1 - player;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment