Commit e1ff8f92 authored by sbl1996@126.com

Add more rnn options and batch norm

parent 974fe861
...@@ -19,14 +19,13 @@ import numpy as np ...@@ -19,14 +19,13 @@ import numpy as np
import optax import optax
import distrax import distrax
import tyro import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint from rich.pretty import pprint
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent, ModelArgs from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample, TrainState
from ygoai.rl.jax.eval import evaluate, battle from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \ from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
...@@ -251,6 +250,14 @@ def create_agent(args, eval=False): ...@@ -251,6 +250,14 @@ def create_agent(args, eval=False):
) )
def get_variables(agent_state):
    """Collect the variable collections held by ``agent_state`` into one dict.

    The result always has a ``'params'`` entry. A ``'batch_stats'`` entry is
    appended only when ``agent_state`` has a ``batch_stats`` attribute whose
    value is not ``None`` (an empty collection is still included).
    """
    stats = getattr(agent_state, "batch_stats", None)
    collections = {'params': agent_state.params}
    if stats is None:
        return collections
    return {**collections, 'batch_stats': stats}
def init_rnn_state(num_envs, rnn_channels): def init_rnn_state(num_envs, rnn_channels):
return ( return (
np.zeros((num_envs, rnn_channels)), np.zeros((num_envs, rnn_channels)),
...@@ -502,11 +509,9 @@ def rollout( ...@@ -502,11 +509,9 @@ def rollout(
sharded_storage.append(x) sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage) sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded( sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices), np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_obs, next_rstate), next_main)) (init_rstate1, init_rstate2, next_obs, next_main))
if args.eval_interval and update % args.eval_interval == 0: if args.eval_interval and update % args.eval_interval == 0:
_start = time.time() _start = time.time()
...@@ -683,13 +688,17 @@ def main(): ...@@ -683,13 +688,17 @@ def main():
agent = create_agent(args) agent = create_agent(args)
rstate = agent.init_rnn_state(1) rstate = agent.init_rnn_state(1)
params = agent.init(init_key, sample_obs, rstate) variables = agent.init(init_key, sample_obs, rstate)
variables = flax.core.unfreeze(variables)
if embeddings is not None: if embeddings is not None:
unknown_embed = embeddings.mean(axis=0) unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0) embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params) variables['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings) # variables = flax.core.freeze(variables)
params = flax.core.freeze(params) if args.checkpoint:
with open(args.checkpoint, "rb") as f:
variables = flax.serialization.from_bytes(variables, f.read())
print(f"loaded checkpoint from {args.checkpoint}")
tx = optax.MultiSteps( tx = optax.MultiSteps(
optax.chain( optax.chain(
...@@ -701,29 +710,29 @@ def main(): ...@@ -701,29 +710,29 @@ def main():
every_k_schedule=1, every_k_schedule=1,
) )
tx = optax.apply_if_finite(tx, max_consecutive_errors=10) tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
if 'batch_stats' not in variables:
# variables = flax.core.unfreeze(variables)
variables['batch_stats'] = {}
# variables = flax.core.freeze(variables)
agent_state = TrainState.create( agent_state = TrainState.create(
apply_fn=None, apply_fn=None,
params=params, params=variables['params'],
tx=tx, tx=tx,
batch_stats=variables['batch_stats'],
) )
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices) agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs)) # print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint: if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True) eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1) eval_rstate = eval_agent.init_rnn_state(1)
eval_params = eval_agent.init(init_key, sample_obs, eval_rstate) eval_variables = eval_agent.init(init_key, sample_obs, eval_rstate)
with open(args.eval_checkpoint, "rb") as f: with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(eval_params, f.read()) eval_variables = flax.serialization.from_bytes(eval_variables, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}") print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else: else:
eval_params = None eval_variables = None
def advantage_fn( def advantage_fn(
new_logits, new_values, next_dones, switch_or_mains, new_logits, new_values, next_dones, switch_or_mains,
...@@ -811,17 +820,29 @@ def main(): ...@@ -811,17 +820,29 @@ def main():
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, pg_loss, v_loss, ent_loss, approx_kl return loss, pg_loss, v_loss, ent_loss, approx_kl
def apply_fn(params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains): def apply_fn(variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
if args.switch: if args.switch:
dones = dones | next_dones dones = dones | next_dones
(rstate1, rstate2), new_logits, new_values = agent.apply( ((rstate1, rstate2), new_logits, new_values, _), state_updates = agent.apply(
params, obs, (rstate1, rstate2), dones, switch_or_mains)[:3] variables, obs, (rstate1, rstate2), dones, switch_or_mains,
train=True, mutable=["batch_stats"])
new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values) new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
return (rstate1, rstate2), new_logits, new_values return ((rstate1, rstate2), new_logits, new_values), state_updates
def compute_next_value(
    variables, rstate1, rstate2, next_obs, next_main):
    """Bootstrap value for the observation that follows a rollout segment.

    Args:
      variables: variable collections passed straight to ``agent.apply``.
      rstate1, rstate2: two recurrent-state pytrees with matching structure.
      next_obs: observation to evaluate.
      next_main: boolean array; per row it picks ``rstate1`` (True) or
        ``rstate2`` (False) and decides the sign of the returned value.
        Presumably shaped ``(num_envs,)`` -- TODO confirm against the caller.

    Returns:
      The third output of ``agent.apply`` with its last axis squeezed and
      gradients stopped, sign-adjusted according to ``next_main`` and
      ``args.switch``.
    """
    def pick_state(x1, x2):
        # Add a trailing axis so the per-row flag broadcasts over state channels.
        return jnp.where(next_main[:, None], x1, x2)

    rstate = jax.tree.map(pick_state, rstate1, rstate2)
    # No ``train``/``mutable`` arguments are passed, so the agent's defaults apply.
    outputs = agent.apply(variables, next_obs, rstate)
    bootstrap = jax.tree.map(lambda v: v.squeeze(-1), outputs[2])
    bootstrap = jax.lax.stop_gradient(bootstrap)
    if args.switch:
        sign = -1
    else:
        sign = 1
    return jnp.where(next_main, sign * bootstrap, -sign * bootstrap)
def compute_advantage( def compute_advantage(
params, rstate1, rstate2, obs, dones, next_dones, variables, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value): switch_or_mains, actions, logits, rewards, next_obs, next_main):
segment_length = dones.shape[0] segment_length = dones.shape[0]
obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \ obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
...@@ -829,8 +850,11 @@ def main(): ...@@ -829,8 +850,11 @@ def main():
lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
(obs, dones, next_dones, switch_or_mains, actions, logits, rewards)) (obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
new_logits, new_values = apply_fn( ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3] variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
next_value = compute_next_value(
variables, rstate1, rstate2, next_obs, next_main)
target_values, advantages = advantage_fn( target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains, new_logits, new_values, next_dones, switch_or_mains,
...@@ -842,10 +866,11 @@ def main(): ...@@ -842,10 +866,11 @@ def main():
return target_values, advantages return target_values, advantages
def compute_loss( def compute_loss(
params, rstate1, rstate2, obs, dones, next_dones, params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, target_values, advantages, mask): switch_or_mains, actions, logits, target_values, advantages, mask):
(rstate1, rstate2), new_logits, new_values = apply_fn( variables = {'params': params, 'batch_stats': batch_stats}
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains) ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn( loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
new_logits, new_values, actions, logits, target_values, advantages, new_logits, new_values, actions, logits, target_values, advantages,
...@@ -854,14 +879,19 @@ def main(): ...@@ -854,14 +879,19 @@ def main():
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss) loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
approx_kl, rstate1, rstate2 = jax.tree.map( approx_kl, rstate1, rstate2 = jax.tree.map(
jax.lax.stop_gradient, (approx_kl, rstate1, rstate2)) jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
return loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2) return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
def compute_advantage_loss( def compute_advantage_loss(
params, rstate1, rstate2, obs, dones, next_dones, params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value, mask): switch_or_mains, actions, logits, rewards, mask, next_obs, next_main):
num_envs = jax.tree.leaves(next_value)[0].shape[0] num_envs = jax.tree.leaves(next_main)[0].shape[0]
new_logits, new_values = apply_fn( variables = {'params': params, 'batch_stats': batch_stats}
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3] ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
next_value = compute_next_value(
variables, rstate1, rstate2, next_obs, next_main)
target_values, advantages = advantage_fn( target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains, new_logits, new_values, next_dones, switch_or_mains,
...@@ -873,22 +903,21 @@ def main(): ...@@ -873,22 +903,21 @@ def main():
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss) loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
approx_kl = jax.lax.stop_gradient(approx_kl) approx_kl = jax.lax.stop_gradient(approx_kl)
return loss, (pg_loss, v_loss, ent_loss, approx_kl) return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)
def single_device_update( def single_device_update(
agent_state: TrainState, agent_state: TrainState,
sharded_storages: List, sharded_storages: List,
sharded_init_rstate1: List, sharded_init_rstate1: List,
sharded_init_rstate2: List, sharded_init_rstate2: List,
sharded_next_inputs: List, sharded_next_obs: List,
sharded_next_main: List, sharded_next_main: List,
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
): ):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages) storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
# TODO: rstate will be out-date after the first update, maybe consider R2D2 next_obs, init_rstate1, init_rstate2 = [
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x) jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2] for x in [sharded_next_obs, sharded_init_rstate1, sharded_init_rstate2]
] ]
next_main = jnp.concatenate(sharded_next_main) next_main = jnp.concatenate(sharded_next_main)
...@@ -913,49 +942,45 @@ def main(): ...@@ -913,49 +942,45 @@ def main():
agent_state, key = carry agent_state, key = carry
key, subkey = jax.random.split(key) key, subkey = jax.random.split(key)
next_value = agent.apply(agent_state.params, *next_inputs)[2]
next_value = jax.tree.map(lambda x: jnp.squeeze(x, axis=-1), next_value)
sign = -1 if args.switch else 1
next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
def convert_data(x: jnp.ndarray, multi_step=True): def convert_data(x: jnp.ndarray, multi_step=True):
key = subkey if args.update_epochs > 1 else None key = subkey if args.update_epochs > 1 else None
return reshape_minibatch( return reshape_minibatch(
x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key) x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)
shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map( b_init_rstate1, b_init_rstate2, b_next_obs, b_next_main = \
partial(convert_data, multi_step=False), (init_rstate1, init_rstate2)) jax.tree.map(partial(convert_data, multi_step=False),
shuffled_storage = jax.tree.map(convert_data, storage) (init_rstate1, init_rstate2, next_obs, next_main))
b_storage = jax.tree.map(convert_data, storage)
if args.switch: if args.switch:
switch_or_mains = convert_data(switch) switch_or_mains = convert_data(switch)
else: else:
switch_or_mains = shuffled_storage.mains switch_or_mains = b_storage.mains
shuffled_mask = ~shuffled_storage.dones b_mask = ~b_storage.dones
shuffled_next_value = jax.tree.map( b_rewards = b_storage.rewards
partial(convert_data, multi_step=False), next_value)
shuffled_rewards = shuffled_storage.rewards
if args.segment_length is None: if args.segment_length is None:
def update_minibatch(agent_state, minibatch): def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn( (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl)), grads = \
agent_state.params, *minibatch) loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices") grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads) agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
else: else:
def update_minibatch(carry, minibatch): def update_minibatch(carry, minibatch):
def update_minibatch_t(carry, minibatch_t): def update_minibatch_t(carry, minibatch_t):
agent_state, rstate1, rstate2 = carry agent_state, rstate1, rstate2 = carry
minibatch_t = rstate1, rstate2, *minibatch_t minibatch_t = rstate1, rstate2, *minibatch_t
(loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \ (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
grads = loss_grad_fn(agent_state.params, *minibatch_t) grads = loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch_t)
grads = jax.lax.pmean(grads, axis_name="local_devices") grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads) agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl) return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
rstate1, rstate2, *minibatch_t, mask = minibatch rstate1, rstate2, *minibatch_t, mask = minibatch
target_values, advantages = compute_advantage( target_values, advantages = compute_advantage(
carry.params, rstate1, rstate2, *minibatch_t) get_variables(carry), rstate1, rstate2, *minibatch_t)
minibatch_t = *minibatch_t[:-2], target_values, advantages, mask minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
(carry, _rstate1, _rstate2), \ (carry, _rstate1, _rstate2), \
...@@ -967,17 +992,18 @@ def main(): ...@@ -967,17 +992,18 @@ def main():
update_minibatch, update_minibatch,
agent_state, agent_state,
( (
shuffled_init_rstate1, b_init_rstate1,
shuffled_init_rstate2, b_init_rstate2,
shuffled_storage.obs, b_storage.obs,
shuffled_storage.dones, b_storage.dones,
shuffled_storage.next_dones, b_storage.next_dones,
switch_or_mains, switch_or_mains,
shuffled_storage.actions, b_storage.actions,
shuffled_storage.logits, b_storage.logits,
shuffled_rewards, b_rewards,
shuffled_next_value, b_mask,
shuffled_mask, b_next_obs,
b_next_main,
), ),
) )
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl) return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
...@@ -1007,16 +1033,16 @@ def main(): ...@@ -1007,16 +1033,16 @@ def main():
params_queues = [] params_queues = []
rollout_queues = [] rollout_queues = []
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params) unreplicated_params = flax.jax_utils.unreplicate(get_variables(agent_state))
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
actor_device = local_devices[d_id] actor_device = local_devices[d_id]
device_params = jax.device_put(unreplicated_params, actor_device) device_params = jax.device_put(unreplicated_params, actor_device)
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1)) params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1))
if eval_params: if eval_variables:
params_queues[-1].put( params_queues[-1].put(
jax.device_put(eval_params, actor_device)) jax.device_put(eval_variables, actor_device))
actor_thread_id = d_idx * args.num_actor_threads + thread_id actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread( threading.Thread(
target=rollout, target=rollout,
...@@ -1070,7 +1096,7 @@ def main(): ...@@ -1070,7 +1096,7 @@ def main():
*list(zip(*sharded_data_list)), *list(zip(*sharded_data_list)),
learner_keys, learner_keys,
) )
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params) unreplicated_params = flax.jax_utils.unreplicate(get_variables(agent_state))
params_queue_put_time = 0 params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id]) device_params = jax.device_put(unreplicated_params, local_devices[d_id])
......
...@@ -8,7 +8,7 @@ import jax.numpy as jnp ...@@ -8,7 +8,7 @@ import jax.numpy as jnp
import flax.linen as nn import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
from ygoai.rl.jax.modules import MLP, GLUMlp, RMSNorm, make_bin_params, bytes_to_bin, decode_id from ygoai.rl.jax.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.rwkv import Rwkv6SelfAttention from ygoai.rl.jax.rwkv import Rwkv6SelfAttention
...@@ -487,7 +487,7 @@ class Critic(nn.Module): ...@@ -487,7 +487,7 @@ class Critic(nn.Module):
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
@nn.compact @nn.compact
def __call__(self, f_state): def __call__(self, f_state, train):
f_state = f_state.astype(self.dtype) f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype) mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state) x = mlp(self.channels, last_lin=False)(f_state)
...@@ -495,6 +495,33 @@ class Critic(nn.Module): ...@@ -495,6 +495,33 @@ class Critic(nn.Module):
return x return x
class CrossCritic(nn.Module):
    """MLP value head that applies ``BatchRenorm`` around every hidden layer.

    Attributes:
      channels: width of each hidden ``Dense`` layer, in order.
      batch_norm_momentum: momentum forwarded to every ``BatchRenorm`` layer.
      dtype: computation dtype of the hidden layers (the output layer always
        computes in float32).
      param_dtype: dtype passed to parameter initializers.
    """
    channels: Sequence[int] = (128, 128, 128)
    batch_norm_momentum: float = 0.99
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, f_state, train):
        """Map features ``f_state`` to a value with a trailing axis of size 1.

        Args:
          f_state: input features.
          train: when true the normalization layers use batch statistics,
            otherwise they use the running averages stored in ``batch_stats``.
        """
        x = f_state.astype(self.dtype)
        # Hidden layers skip the bias: the BatchRenorm that follows has its own.
        linear = partial(
            nn.Dense, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False)
        # axis_name ties the batch statistics to the "local_devices" mapped axis.
        BN = partial(
            BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
            momentum=self.batch_norm_momentum, axis_name="local_devices",
            use_running_average=not train)
        x = BN()(x)
        for c in self.channels:
            x = linear(c)(x)
            # nn.relu is a plain function, not a Module: it must be applied to
            # ``x`` directly (``nn.relu()`` raises TypeError for the missing arg).
            x = nn.relu(x)
            x = BN()(x)
        x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype)(x)
        return x
class GlobalCritic(nn.Module): class GlobalCritic(nn.Module):
channels: Sequence[int] = (128, 128) channels: Sequence[int] = (128, 128)
dtype: Optional[jnp.dtype] = None dtype: Optional[jnp.dtype] = None
...@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs): ...@@ -580,6 +607,14 @@ class ModelArgs(EncoderArgs):
"""whether to use FiLM for the actor""" """whether to use FiLM for the actor"""
oppo_info: bool = False oppo_info: bool = False
"""whether to use opponent's information""" """whether to use opponent's information"""
rnn_shortcut: bool = False
"""whether to use shortcut for the RNN"""
batch_norm: bool = False
"""whether to use batch normalization for the critic"""
critic_width: int = 128
"""the width of the critic"""
critic_depth: int = 3
"""the depth of the critic"""
rwkv_head_size: int = 32 rwkv_head_size: int = 32
"""the head size for the RWKV""" """the head size for the RWKV"""
...@@ -596,6 +631,10 @@ class RNNAgent(nn.Module): ...@@ -596,6 +631,10 @@ class RNNAgent(nn.Module):
rwkv_head_size: int = 32 rwkv_head_size: int = 32
action_feats: bool = True action_feats: bool = True
oppo_info: bool = False oppo_info: bool = False
rnn_shortcut: bool = False
batch_norm: bool = False
critic_width: int = 128
critic_depth: int = 3
version: int = 0 version: int = 0
switch: bool = True switch: bool = True
...@@ -606,7 +645,7 @@ class RNNAgent(nn.Module): ...@@ -606,7 +645,7 @@ class RNNAgent(nn.Module):
param_dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32
@nn.compact @nn.compact
def __call__(self, x, rstate, done=None, switch_or_main=None): def __call__(self, x, rstate, done=None, switch_or_main=None, train=False):
batch_size = jax.tree.leaves(rstate)[0].shape[0] batch_size = jax.tree.leaves(rstate)[0].shape[0]
c = self.num_channels c = self.num_channels
...@@ -669,6 +708,10 @@ class RNNAgent(nn.Module): ...@@ -669,6 +708,10 @@ class RNNAgent(nn.Module):
rstate, f_state_r = rnn_step_by_main( rstate, f_state_r = rnn_step_by_main(
rnn_layer, rstate, f_state, done, switch_or_main) rnn_layer, rstate, f_state, done, switch_or_main)
if self.rnn_shortcut:
# f_state_r = ReZero(channel_wise=True)(f_state_r)
f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
if self.film: if self.film:
actor = FiLMActor( actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam) channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
...@@ -694,13 +737,16 @@ class RNNAgent(nn.Module): ...@@ -694,13 +737,16 @@ class RNNAgent(nn.Module):
lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2) lambda x1, x2: jnp.where(main, x2, x1), rstate1, rstate2)
value = critic(rstate1_t, rstate2_t, f_g) value = critic(rstate1_t, rstate2_t, f_g)
else: else:
critic = Critic( CriticCls = CrossCritic if self.batch_norm else Critic
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype) cs = [self.critic_width] * self.critic_depth
value = critic(f_state_r) critic = CriticCls(
channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value = critic(f_state_r, train)
if self.int_head: if self.int_head:
cs = [self.critic_width] * self.critic_depth
critic_int = Critic( critic_int = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype) channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
value_int = critic_int(f_state_r) value_int = critic_int(f_state_r)
value = (value, value_int) value = (value, value_int)
return rstate, logits, value, valid return rstate, logits, value, valid
......
from typing import Tuple, Union, Optional from typing import Tuple, Union, Optional, Any
import functools import functools
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax.linen as nn import flax.linen as nn
from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes
def decode_id(x): def decode_id(x):
...@@ -109,4 +110,145 @@ class RMSNorm(nn.Module): ...@@ -109,4 +110,145 @@ class RMSNorm(nn.Module):
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype "scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
) )
x = x * scale x = x * scale
return jnp.asarray(x, self.dtype) return jnp.asarray(x, self.dtype)
\ No newline at end of file
class ReZero(nn.Module):
    """Scale the input by a learnable ``scale`` parameter initialized to zero.

    Attributes:
      channel_wise: if true the parameter has one entry per element of the
        input's last axis; otherwise it is a single scalar.
      param_dtype: dtype passed to the parameter initializer.
    """
    channel_wise: bool = False
    param_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        if self.channel_wise:
            scale_shape = (x.shape[-1],)
        else:
            scale_shape = ()
        gain = self.param(
            "scale", nn.initializers.zeros, scale_shape, self.param_dtype)
        return x * gain
class BatchRenorm(nn.Module):
    """Batch renormalization layer.

    Implemented following the Batch Renormalization paper
    (https://arxiv.org/abs/1702.03275) and adapted from Flax's BatchNorm
    implementation:
    https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228

    Besides the running ``mean`` and ``var``, the ``batch_stats`` collection
    also stores the clipping bounds ``r_max`` and ``d_max`` and a ``steps``
    counter that is incremented on every non-initializing call made with
    ``use_running_average=False``.

    Attributes:
      use_running_average: if True, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the result (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      use_bias: if True, bias (beta) is added.
      use_scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
        examples on the first two and last two devices. See `jax.lax.psum` for
        more details.
      use_fast_variance: If true, use a faster, but less numerically stable,
        calculation for the variance.
      warmup_steps: value of the ``steps`` counter from which the
        renormalization correction is applied; before that, plain batch
        statistics are used so the running statistics can build up.
    """
    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.999
    epsilon: float = 0.001
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: nn.initializers.Initializer = nn.initializers.zeros
    scale_init: nn.initializers.Initializer = nn.initializers.ones
    axis_name: Optional[str] = None
    axis_index_groups: Any = None
    use_fast_variance: bool = True
    warmup_steps: int = 100_000

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalize ``x``.

        Args:
          x: the input to be normalized.
          use_running_average: if true, the statistics stored in batch_stats will be
            used instead of computing the batch statistics on the input.

        Returns:
          Normalized inputs (the same shape as inputs).
        """
        use_running_average = nn.merge_param(
            'use_running_average', self.use_running_average, use_running_average
        )
        feature_axes = _canonicalize_axes(x.ndim, self.axis)
        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
        feature_shape = [x.shape[ax] for ax in feature_axes]

        ra_mean = self.variable(
            'batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), feature_shape)
        ra_var = self.variable(
            'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape)
        # Clipping bounds and the step counter live next to the running statistics.
        r_max = self.variable('batch_stats', 'r_max', lambda s: s, 3)
        d_max = self.variable('batch_stats', 'd_max', lambda s: s, 5)
        steps = self.variable('batch_stats', 'steps', lambda s: s, 0)

        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
            custom_mean = mean
            custom_var = var
        else:
            mean, var = _compute_stats(
                x,
                reduction_axes,
                dtype=self.dtype,
                # No cross-device reduction while initializing.
                axis_name=self.axis_name if not self.is_initializing() else None,
                axis_index_groups=self.axis_index_groups,
                use_fast_variance=self.use_fast_variance,
            )
            custom_mean = mean
            custom_var = var
            if not self.is_initializing():
                # The code below is implemented following the Batch Renormalization paper
                std = jnp.sqrt(var + self.epsilon)
                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
                # r and d are treated as constants by the backward pass.
                r = jax.lax.stop_gradient(std / ra_std)
                r = jnp.clip(r, 1 / r_max.value, r_max.value)
                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
                d = jnp.clip(d, -d_max.value, d_max.value)
                tmp_var = var / (r**2)
                tmp_mean = mean - d * jnp.sqrt(custom_var) / r

                # Warm up batch renorm for ``warmup_steps`` steps to build up
                # proper running statistics; the 0/1 blend keeps this traceable.
                warmed_up = jnp.greater_equal(
                    steps.value, self.warmup_steps).astype(jnp.float32)
                custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
                custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean

                ra_mean.value = (
                    self.momentum * ra_mean.value + (1 - self.momentum) * mean
                )
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
                steps.value += 1

        return _normalize(
            self,
            x,
            custom_mean,
            custom_var,
            reduction_axes,
            feature_axes,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_bias,
            self.use_scale,
            self.bias_init,
            self.scale_init,
        )
from typing import Any, Callable
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax import core, struct
from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
import optax
import numpy as np import numpy as np
from ygoai.rl.env import RecordEpisodeStatistics from ygoai.rl.env import RecordEpisodeStatistics
...@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments( ...@@ -67,3 +74,72 @@ def update_mean_var_count_from_moments(
new_count = tot_count new_count = tot_count
return new_mean, new_var, new_count return new_mean, new_var, new_count
class TrainState(struct.PyTreeNode):
    """Training state bundling parameters, optimizer state and ``batch_stats``.

    ``apply_fn`` and ``tx`` are static (non-pytree) fields; ``params``,
    ``opt_state`` and ``batch_stats`` are pytree nodes.
    """
    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState = struct.field(pytree_node=True)
    batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)

    def apply_gradients(self, *, grads, **kwargs):
        """Return a copy of the state advanced by one optimizer step.

        Runs ``self.tx.update()`` and then ``optax.apply_updates()``.

        Args:
          grads: Gradients that have the same pytree structure as ``.params``.
          **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.

        Returns:
          A new instance with ``step`` incremented by one, ``params`` and
          ``opt_state`` updated from ``grads``, and the attributes named in
          ``kwargs`` replaced.
        """
        has_owg = OVERWRITE_WITH_GRADIENT in grads
        # With an OWG collection present only the 'params' sub-tree is optimized.
        opt_grads = grads['params'] if has_owg else grads
        opt_params = self.params['params'] if has_owg else self.params

        updates, next_opt_state = self.tx.update(
            opt_grads, self.opt_state, opt_params)
        stepped = optax.apply_updates(opt_params, updates)

        if has_owg:
            # As implied by the OWG name, those entries are overwritten with
            # their gradients rather than going through the optimizer.
            stepped = {
                'params': stepped,
                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
            }
        return self.replace(
            step=self.step + 1,
            params=stepped,
            opt_state=next_opt_state,
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, **kwargs):
        """Build an instance with ``step=0`` and a freshly initialized ``opt_state``."""
        # OWG params are excluded because they do not need optimizer state.
        if OVERWRITE_WITH_GRADIENT in params:
            opt_params = params['params']
        else:
            opt_params = params
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=tx.init(opt_params),
            **kwargs,
        )
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment