Commit e1ff8f92 authored by sbl1996@126.com

Add more RNN options and batch norm

parent 974fe861
......@@ -19,14 +19,13 @@ import numpy as np
import optax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample, TrainState
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
......@@ -251,6 +250,14 @@ def create_agent(args, eval=False):
)
def get_variables(agent_state):
batch_stats = getattr(agent_state, "batch_stats", None)
variables = {'params': agent_state.params}
if batch_stats is not None:
variables['batch_stats'] = batch_stats
return variables
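For reference, a hedged sketch of how this helper is consumed (mirroring the actor and eval paths below; `obs` and `rstate` are illustrative):
variables = get_variables(agent_state)
# {'params': ...} for plain critics, plus 'batch_stats' when batch norm is enabled.
rstate, logits, value, valid = agent.apply(variables, obs, rstate)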
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
......@@ -502,11 +509,9 @@ def rollout(
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_obs, next_rstate), next_main))
(init_rstate1, init_rstate2, next_obs, next_main))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
......@@ -683,13 +688,17 @@ def main():
agent = create_agent(args)
rstate = agent.init_rnn_state(1)
params = agent.init(init_key, sample_obs, rstate)
variables = agent.init(init_key, sample_obs, rstate)
variables = flax.core.unfreeze(variables)
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
variables['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
# variables = flax.core.freeze(variables)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
variables = flax.serialization.from_bytes(variables, f.read())
print(f"loaded checkpoint from {args.checkpoint}")
tx = optax.MultiSteps(
optax.chain(
......@@ -701,29 +710,29 @@ def main():
every_k_schedule=1,
)
tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
if 'batch_stats' not in variables:
# variables = flax.core.unfreeze(variables)
variables['batch_stats'] = {}
# variables = flax.core.freeze(variables)
agent_state = TrainState.create(
apply_fn=None,
params=params,
params=variables['params'],
tx=tx,
batch_stats=variables['batch_stats'],
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1)
eval_params = eval_agent.init(init_key, sample_obs, eval_rstate)
eval_variables = eval_agent.init(init_key, sample_obs, eval_rstate)
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(eval_params, f.read())
eval_variables = flax.serialization.from_bytes(eval_variables, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
eval_variables = None
def advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
......@@ -811,17 +820,29 @@ def main():
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, pg_loss, v_loss, ent_loss, approx_kl
def apply_fn(params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
def apply_fn(variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
if args.switch:
dones = dones | next_dones
(rstate1, rstate2), new_logits, new_values = agent.apply(
params, obs, (rstate1, rstate2), dones, switch_or_mains)[:3]
((rstate1, rstate2), new_logits, new_values, _), state_updates = agent.apply(
variables, obs, (rstate1, rstate2), dones, switch_or_mains,
train=True, mutable=["batch_stats"])
new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
return (rstate1, rstate2), new_logits, new_values
return ((rstate1, rstate2), new_logits, new_values), state_updates
def compute_next_value(
variables, rstate1, rstate2, next_obs, next_main):
rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), rstate1, rstate2)
next_value = agent.apply(variables, next_obs, rstate)[2]
next_value = jax.tree.map(lambda x: x.squeeze(-1), next_value)
next_value = jax.lax.stop_gradient(next_value)
sign = -1 if args.switch else 1
next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
return next_value
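# Illustration (hedged): with args.switch disabled, sign = 1, so envs where
# next_main is True keep the critic's bootstrap value as-is, while on the
# opponent's turn it is negated, matching the zero-sum sign convention used
# for rewards and values elsewhere in this file.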
def compute_advantage(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value):
variables, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_obs, next_main):
segment_length = dones.shape[0]
obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
......@@ -829,8 +850,11 @@ def main():
lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
(obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
next_value = compute_next_value(
variables, rstate1, rstate2, next_obs, next_main)
target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
......@@ -842,10 +866,11 @@ def main():
return target_values, advantages
def compute_loss(
params, rstate1, rstate2, obs, dones, next_dones,
params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, target_values, advantages, mask):
(rstate1, rstate2), new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
variables = {'params': params, 'batch_stats': batch_stats}
((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
new_logits, new_values, actions, logits, target_values, advantages,
......@@ -854,14 +879,19 @@ def main():
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
approx_kl, rstate1, rstate2 = jax.tree.map(
jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
return loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
def compute_advantage_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value, mask):
num_envs = jax.tree.leaves(next_value)[0].shape[0]
new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, mask, next_obs, next_main):
num_envs = jax.tree.leaves(next_main)[0].shape[0]
variables = {'params': params, 'batch_stats': batch_stats}
((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
next_value = compute_next_value(
variables, rstate1, rstate2, next_obs, next_main)
target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
......@@ -873,22 +903,21 @@ def main():
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
approx_kl = jax.lax.stop_gradient(approx_kl)
return loss, (pg_loss, v_loss, ent_loss, approx_kl)
return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_obs: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
# TODO: rstate will be out of date after the first update; maybe consider R2D2
next_inputs, init_rstate1, init_rstate2 = [
next_obs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
for x in [sharded_next_obs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main = jnp.concatenate(sharded_next_main)
......@@ -913,49 +942,45 @@ def main():
agent_state, key = carry
key, subkey = jax.random.split(key)
next_value = agent.apply(agent_state.params, *next_inputs)[2]
next_value = jax.tree.map(lambda x: jnp.squeeze(x, axis=-1), next_value)
sign = -1 if args.switch else 1
next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
def convert_data(x: jnp.ndarray, multi_step=True):
key = subkey if args.update_epochs > 1 else None
return reshape_minibatch(
x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)
shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
shuffled_storage = jax.tree.map(convert_data, storage)
b_init_rstate1, b_init_rstate2, b_next_obs, b_next_main = \
jax.tree.map(partial(convert_data, multi_step=False),
(init_rstate1, init_rstate2, next_obs, next_main))
b_storage = jax.tree.map(convert_data, storage)
if args.switch:
switch_or_mains = convert_data(switch)
else:
switch_or_mains = shuffled_storage.mains
shuffled_mask = ~shuffled_storage.dones
shuffled_next_value = jax.tree.map(
partial(convert_data, multi_step=False), next_value)
shuffled_rewards = shuffled_storage.rewards
switch_or_mains = b_storage.mains
b_mask = ~b_storage.dones
b_rewards = b_storage.rewards
if args.segment_length is None:
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
agent_state.params, *minibatch)
(loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)), grads = \
loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
else:
def update_minibatch(carry, minibatch):
def update_minibatch_t(carry, minibatch_t):
agent_state, rstate1, rstate2 = carry
minibatch_t = rstate1, rstate2, *minibatch_t
(loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
grads = loss_grad_fn(agent_state.params, *minibatch_t)
(loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
grads = loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch_t)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
rstate1, rstate2, *minibatch_t, mask = minibatch
target_values, advantages = compute_advantage(
carry.params, rstate1, rstate2, *minibatch_t)
get_variables(carry), rstate1, rstate2, *minibatch_t)
minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
(carry, _rstate1, _rstate2), \
......@@ -967,17 +992,18 @@ def main():
update_minibatch,
agent_state,
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
b_init_rstate1,
b_init_rstate2,
b_storage.obs,
b_storage.dones,
b_storage.next_dones,
switch_or_mains,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_rewards,
shuffled_next_value,
shuffled_mask,
b_storage.actions,
b_storage.logits,
b_rewards,
b_mask,
b_next_obs,
b_next_main,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
......@@ -1007,16 +1033,16 @@ def main():
params_queues = []
rollout_queues = []
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
unreplicated_params = flax.jax_utils.unreplicate(get_variables(agent_state))
for d_idx, d_id in enumerate(args.actor_device_ids):
actor_device = local_devices[d_id]
device_params = jax.device_put(unreplicated_params, actor_device)
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
if eval_params:
if eval_variables:
params_queues[-1].put(
jax.device_put(eval_params, actor_device))
jax.device_put(eval_variables, actor_device))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread(
target=rollout,
......@@ -1070,7 +1096,7 @@ def main():
*list(zip(*sharded_data_list)),
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
unreplicated_params = flax.jax_utils.unreplicate(get_variables(agent_state))
params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
......
......@@ -8,7 +8,7 @@ import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
......@@ -487,7 +487,7 @@ class Critic(nn.Module):
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
def __call__(self, f_state, train):
f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state)
......@@ -495,6 +495,33 @@ class Critic(nn.Module):
return x
class CrossCritic(nn.Module):
channels: Sequence[int] = (128, 128, 128)
# dropout_rate: Optional[float] = None
batch_norm_momentum: float = 0.99
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state, train):
x = f_state.astype(self.dtype)
linear = partial(nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
BN = partial(
BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
momentum=self.batch_norm_momentum, axis_name="local_devices",
use_running_average=not train)
x = BN()(x)
for c in self.channels:
x = linear(c)(x)
# if self.use_layer_norm:
# x = nn.LayerNorm()(x)
x = nn.relu(x)
# x = nn.leaky_relu(x, negative_slope=0.1)
x = BN()(x)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
return x
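A minimal usage sketch (hedged; `feat` and the pmap wrapper are illustrative). Because BatchRenorm here reduces statistics over axis_name="local_devices", train-mode calls must run inside a pmap that defines that axis; evaluation uses the stored running averages and works anywhere:
import functools
import jax
import jax.numpy as jnp

critic = CrossCritic()
feat = jnp.ones((jax.local_device_count(), 4, 256))
variables = critic.init(jax.random.PRNGKey(0), feat[0], train=False)

@functools.partial(jax.pmap, axis_name="local_devices")
def train_values(f):
    # Batch statistics are updated and returned in a separate collection.
    return critic.apply(variables, f, train=True, mutable=["batch_stats"])

values, state_updates = train_values(feat)
values_eval = critic.apply(variables, feat[0], train=False)  # no mutation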
class GlobalCritic(nn.Module):
channels: Sequence[int] = (128, 128)
dtype: Optional[jnp.dtype] = None
......@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs):
"""whether to use FiLM for the actor"""
oppo_info: bool = False
"""whether to use opponent's information"""
rnn_shortcut: bool = False
"""whether to use shortcut for the RNN"""
batch_norm: bool = False
"""whether to use batch normalization for the critic"""
critic_width: int = 128
"""the width of the critic"""
critic_depth: int = 3
"""the depth of the critic"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
......@@ -596,6 +631,10 @@ class RNNAgent(nn.Module):
rwkv_head_size: int = 32
action_feats: bool = True
oppo_info: bool = False
rnn_shortcut: bool = False
batch_norm: bool = False
critic_width: int = 128
critic_depth: int = 3
version: int = 0
switch: bool = True
......@@ -606,7 +645,7 @@ class RNNAgent(nn.Module):
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None):
def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
batch_size = jax.tree.leaves(rstate)[0].shape[0]
c = self.num_channels
......@@ -669,6 +708,10 @@ class RNNAgent(nn.Module):
rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main)
if self.rnn_shortcut:
# f_state_r = ReZero(channel_wise=True)(f_state_r)
f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
......@@ -694,13 +737,16 @@ class RNNAgent(nn.Module):
lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
value = critic(rstate1_t, rstate2_t, f_g)
else:
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
value = critic(f_state_r)
CriticCls = CrossCritic if self.batch_norm else Critic
cs = [self.critic_width] * self.critic_depth
critic = CriticCls(
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value = critic(f_state_r, train)
if self.int_head:
cs = [self.critic_width] * self.critic_depth
critic_int = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value_int = critic_int(f_state_r, train)
value = (value, value_int)
return rstate, logits, value, valid
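A hedged sketch of wiring the new options together (values are illustrative, not recommendations; `sample_obs` as in the training script):
agent = RNNAgent(
    rnn_shortcut=True,   # concatenate pre-RNN features with the RNN output
    batch_norm=True,     # use CrossCritic (BatchRenorm) instead of Critic
    critic_width=256,    # channels per critic layer
    critic_depth=3,      # number of critic layers
)
variables = agent.init(jax.random.PRNGKey(0), sample_obs, agent.init_rnn_state(1))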
......
from typing import Tuple, Union, Optional
from typing import Tuple, Union, Optional, Any
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes
def decode_id(x):
......@@ -110,3 +111,144 @@ class RMSNorm(nn.Module):
)
x = x * scale
return jnp.asarray(x, self.dtype)
class ReZero(nn.Module):
channel_wise: bool = False
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
shape = (x.shape[-1],) if self.channel_wise else ()
scale = self.param("scale", nn.initializers.zeros, shape, self.param_dtype)
return x * scale
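# Hedged note: the scale is zero-initialized, so a branch wrapped in ReZero
# contributes nothing at initialization and learns its magnitude during
# training (cf. ReZero, https://arxiv.org/abs/2003.04887); this is what the
# commented-out rnn_shortcut variant in RNNAgent would rely on.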
class BatchRenorm(nn.Module):
"""BatchRenorm Module, implemented based on the Batch Renormalization paper (https://arxiv.org/abs/1702.03275).
and adapted from Flax's BatchNorm implementation:
https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228
Attributes:
use_running_average: if True, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of the batch
statistics.
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
examples on the first two and last two devices. See `jax.lax.psum` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""
use_running_average: Optional[bool] = None
axis: int = -1
momentum: float = 0.999
epsilon: float = 0.001
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: nn.initializers.Initializer = nn.initializers.zeros
scale_init: nn.initializers.Initializer = nn.initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True
@nn.compact
def __call__(self, x, use_running_average: Optional[bool] = None):
"""
Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.
Returns:
Normalized inputs (the same shape as inputs).
"""
use_running_average = nn.merge_param(
'use_running_average', self.use_running_average, use_running_average
)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
feature_shape = [x.shape[ax] for ax in feature_axes]
ra_mean = self.variable(
'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape)
ra_var = self.variable(
'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape)
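# The renorm clipping bounds and the step counter are stored as scalars in
# batch_stats; the init fn `lambda s: s` simply returns the constant passed
# as its argument (3, 5 and 0 respectively).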
r_max = self.variable('batch_stats', 'r_max', lambda s: s, 3)
d_max = self.variable('batch_stats', 'd_max', lambda s: s, 5)
steps = self.variable('batch_stats', 'steps', lambda s: s, 0)
if use_running_average:
mean, var = ra_mean.value, ra_var.value
custom_mean = mean
custom_var = var
else:
mean, var = _compute_stats(
x,
reduction_axes,
dtype=self.dtype,
axis_name=self.axis_name if not self.is_initializing() else None,
axis_index_groups=self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
)
custom_mean = mean
custom_var = var
if not self.is_initializing():
# The code below is implemented following the Batch Renormalization paper
std = jnp.sqrt(var + self.epsilon)
ra_std = jnp.sqrt(ra_var.value + self.epsilon)
r = jax.lax.stop_gradient(std / ra_std)
r = jnp.clip(r, 1 / r_max.value, r_max.value)
d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
d = jnp.clip(d, -d_max.value, d_max.value)
tmp_var = var / (r**2)
tmp_mean = mean - d * jnp.sqrt(custom_var) / r
# Use plain batch statistics for the first 100_000 steps so the running
# statistics can warm up before the renorm correction kicks in
warmed_up = jnp.greater_equal(steps.value, 100_000).astype(jnp.float32)
custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean
ra_mean.value = (
self.momentum * ra_mean.value + (1 - self.momentum) * mean
)
ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
steps.value += 1
return _normalize(
self,
x,
custom_mean,
custom_var,
reduction_axes,
feature_axes,
self.dtype,
self.param_dtype,
self.epsilon,
self.use_bias,
self.use_scale,
self.bias_init,
self.scale_init,
)
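A self-contained sketch of the train/eval contract (hedged; shapes are illustrative and axis_name is left at its default, so this runs on a single device):
import jax
import jax.numpy as jnp

bn = BatchRenorm(momentum=0.99)
x = jnp.ones((8, 64))
variables = bn.init(jax.random.PRNGKey(0), x, use_running_average=False)

# Training: running statistics live in 'batch_stats' and must be marked mutable.
y, state_updates = bn.apply(
    variables, x, use_running_average=False, mutable=["batch_stats"])
variables = {**variables, "batch_stats": state_updates["batch_stats"]}

# Evaluation: stored running averages are used and nothing is mutated.
y_eval = bn.apply(variables, x, use_running_average=True)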
from typing import Any, Callable
import jax
import jax.numpy as jnp
from flax import core, struct
from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
import optax
import numpy as np
from ygoai.rl.env import RecordEpisodeStatistics
......@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments(
new_count = tot_count
return new_mean, new_var, new_count
class TrainState(struct.PyTreeNode):
step: int
apply_fn: Callable = struct.field(pytree_node=False)
params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: optax.OptState = struct.field(pytree_node=True)
batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
def apply_gradients(self, *, grads, **kwargs):
"""Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
Note that internally this function calls ``.tx.update()`` followed by a call
to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
Args:
grads: Gradients that have the same pytree structure as ``.params``.
**kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
Returns:
An updated instance of ``self`` with ``step`` incremented by one, ``params``
and ``opt_state`` updated by applying ``grads``, and additional attributes
replaced as specified by ``kwargs``.
"""
if OVERWRITE_WITH_GRADIENT in grads:
grads_with_opt = grads['params']
params_with_opt = self.params['params']
else:
grads_with_opt = grads
params_with_opt = self.params
updates, new_opt_state = self.tx.update(
grads_with_opt, self.opt_state, params_with_opt
)
new_params_with_opt = optax.apply_updates(params_with_opt, updates)
# As implied by the OWG name, the gradients are used directly to update the
# parameters.
if OVERWRITE_WITH_GRADIENT in grads:
new_params = {
'params': new_params_with_opt,
OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
}
else:
new_params = new_params_with_opt
return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
)
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
"""Creates a new instance with ``step=0`` and initialized ``opt_state``."""
# We exclude OWG params when present because they do not need opt states.
params_with_opt = (
params['params'] if OVERWRITE_WITH_GRADIENT in params else params
)
opt_state = tx.init(params_with_opt)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
tx=tx,
opt_state=opt_state,
**kwargs,
)
\ No newline at end of file
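A hedged end-to-end sketch of this TrainState variant with a batch-norm model (assuming BatchRenorm from ygoai.rl.jax.modules is importable; the Tiny module and data are illustrative):
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from ygoai.rl.jax.modules import BatchRenorm

class Tiny(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = False):
        x = BatchRenorm(use_running_average=not train)(x)
        return nn.Dense(1)(x)

model = Tiny()
x = jnp.ones((8, 4))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
state = TrainState.create(
    apply_fn=model.apply, params=variables["params"], tx=optax.adam(1e-3),
    batch_stats=variables["batch_stats"])

def loss_fn(params, batch_stats):
    out, updates = model.apply(
        {"params": params, "batch_stats": batch_stats}, x,
        train=True, mutable=["batch_stats"])
    return jnp.mean(out ** 2), updates

(loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
    state.params, state.batch_stats)
state = state.apply_gradients(grads=grads)                 # step/params/opt_state
state = state.replace(batch_stats=updates["batch_stats"])  # running stats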