Commit fcaf7bf7 authored by sbl1996@126.com

Add nnx

parent 632f551d
import os
import shutil
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field, asdict
from types import SimpleNamespace
from typing import List, NamedTuple, Optional, Literal
from functools import partial
import ygoenv
import flax
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import distrax
import orbax.checkpoint as orbax
import tyro
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.utils import RecordEpisodeStatistics, EnvPreprocess
from ygoai.rl.ckpt import ModelCheckpoint
from ygoai.rl.jax.agent import RNNAgent as RNNAgentE
from ygoai.rl.jax.nnx.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import masked_normalize, categorical_sample, TrainState
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
ach_loss, policy_gradient_loss, vtrace, vtrace_sep, truncated_gae, truncated_gae_sep
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.splitext(os.path.basename(__file__))[0]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
time_log_freq: int = 0
"""the logging frequency of the deck time statistics, 0 to disable"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
timeout: int = 600
"""the timeout of the environment step"""
debug: bool = False
"""whether to run the script in debug mode"""
tb_dir: str = "runs"
"""the directory to save the tensorboard logs"""
tb_offset: int = 0
"""the step offset of the tensorboard logs"""
run_name: Optional[str] = None
"""the name of the tensorboard run"""
ckpt_dir: str = "checkpoints"
"""the directory to save the model checkpoints"""
# Algorithm specific arguments
env_id: str = "YGOPro-v1"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = False
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_steps: Optional[int] = None
"""the number of steps to compute the advantages"""
segment_length: Optional[int] = None
"""the length of the segment for training"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
num_minibatches: int = 64
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
switch: bool = False
"""Toggle the use of switch mechanism"""
norm_adv: bool = False
"""Toggles advantages normalization"""
burn_in_steps: Optional[int] = None
"""the number of burn-in steps for training (for R2D2)"""
upgo: bool = True
"""Toggle the use of UPGO for advantages"""
sep_value: bool = True
"""Whether separate value function computation for each player"""
value: Literal["vtrace", "gae"] = "vtrace"
"""the method to learn the value function"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
c_clip_min: float = 0.001
"""the minimum value of the V-trace trace coefficient (c) clipping"""
c_clip_max: float = 1.007
"""the maximum value of the V-trace trace coefficient (c) clipping"""
rho_clip_min: float = 0.001
"""the minimum value of the V-trace importance weight (rho) clipping"""
rho_clip_max: float = 1.007
"""the maximum value of the V-trace importance weight (rho) clipping"""
ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25
"""the PPO surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = 3.0
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy, typically 0.02"""
logits_threshold: Optional[float] = None
"""the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
vloss_clip: Optional[float] = None
"""the value loss clipping coefficient"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 1.0
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
m1: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the model arguments for the agent"""
m2: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the model arguments for the eval agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = False
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 100
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_devices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
deck_names: Optional[List[str]] = None
real_seed: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
timeout=args.timeout,
oppo_info=False,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logits: list
values: list
rewards: list
mains: list
next_dones: list
def create_agent(args, rngs=None, eval=False):
if eval:
return RNNAgentE(
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
**asdict(args.m2),
)
else:
return RNNAgent(
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
switch=args.switch,
freeze_id=args.freeze_id,
rngs=rngs,
**asdict(args.m1),
)
def get_state(agent_state):
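# merge the trainable params and batch stats back into a single nnx.State (the full model state)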
return nnx.State.merge(agent_state.params, agent_state.batch_stats)
def reshape_minibatch(
x, multi_step, num_minibatches, num_steps, segment_length=None, key=None):
# if segment_length is None,
# n_mb = num_minibatches
# if multi_step, from (num_steps, num_envs, ...)) to
# (n_mb, num_steps * (num_envs // n_mb), ...)
# else, from (num_envs, ...) to
# (n_mb, num_envs // n_mb, ...)
# else,
# n_mb_t = num_steps // segment_length
# n_mb_e = num_minibatches // n_mb_t
# if multi_step, from (num_steps, num_envs, ...)) to
# (n_mb_e, n_mb_t, segment_length * (num_envs // n_mb_e), ...)
# else, from (num_envs, ...) to
# (n_mb_e, num_envs // n_mb_e, ...)
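# worked example (added for illustration): num_steps=128, num_envs=256, num_minibatches=64:
#   segment_length=None, multi_step=True:  (128, 256, ...) -> (64, 512, ...)
#   segment_length=None, multi_step=False: (256, ...) -> (64, 4, ...)
#   segment_length=32 (n_mb_t=4, n_mb_e=16), multi_step=True:
#     (128, 256, ...) -> (16, 4, 512, ...)
#   segment_length=32, multi_step=False: (256, ...) -> (16, 16, ...)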
if key is not None:
x = jax.random.permutation(key, x, axis=1 if multi_step else 0)
N = num_minibatches
if segment_length is None:
if multi_step:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
else:
M = segment_length
Nt = num_steps // M
Ne = N // Nt
if multi_step:
x = jnp.reshape(x, (Nt, M, Ne, -1) + x.shape[2:])
x = x.transpose(2, 0, 1, *range(3, x.ndim))
x = jnp.reshape(x, (Ne, Nt, -1) + x.shape[4:])
else:
x = jnp.reshape(x, (Ne, -1) + x.shape[1:])
return x
def advantage_fn(
args, next_v, values, rewards, next_dones, switch_or_mains, ratios=None, return_carry=False):
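# note (added): with value="vtrace" this performs IMPALA-style off-policy correction:
# `ratios` (target/behavior probability ratios of the taken actions) are clipped to
# [rho_clip_min, rho_clip_max] for the TD errors and [c_clip_min, c_clip_max] for the
# trace coefficients; with value="gae" the ratios are unused and truncated GAE(lambda)
# is applied. `upgo` mixes UPGO returns into the policy targets in both cases.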
if args.switch:
if args.value == "vtrace" or args.sep_value or return_carry:
raise NotImplementedError
return gae_sep_switch(
next_v, values, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo)
else:
# TODO: TD(lambda) for multi-step
if args.value == "gae":
adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
return adv_fn(
next_v, values, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo, return_carry=return_carry)
else:
adv_fn = vtrace_sep if args.sep_value else vtrace
if ratios is None:
ratios = jnp.ones_like(values)
return adv_fn(
next_v, ratios, values, rewards, next_dones, switch_or_mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max,
args.upgo, return_carry=return_carry)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue,
writer,
actor_device,
learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
local_seed = args.real_seed + device_thread_id * args.local_num_envs
np.random.seed(local_seed)
envs = make_env(
args,
local_seed,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = EnvPreprocess(envs, skip_mask=True)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
local_seed + 100000,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = EnvPreprocess(eval_envs, skip_mask=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
agent = nnx.eval_shape(lambda: create_agent(args, rngs=nnx.Rngs(0)))
agent.eval()
def apply_fn(params, *args):
return nnx.split(agent)[0].apply(params)(*args)[0]
old_eval = eval_mode != 'bot'
if old_eval:
eval_agent = create_agent(args, eval=True)
eval_apply_fn = eval_agent.apply
else:
eval_agent = nnx.eval_shape(lambda: create_agent(args, rngs=nnx.Rngs(0), eval=False))
eval_agent.eval()
_eval_apply_fn = nnx.split(eval_agent)[0].apply
def eval_apply_fn(params, *args):
return _eval_apply_fn(params)(*args)[0]
@jax.jit
def get_action(params, obs, rstate):
rstate, logits = eval_apply_fn(params, obs, rstate)[:2]
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, obs, rstate1, rstate2, main, done):
next_rstate1, logits1 = apply_fn(params1, obs, rstate1)[:2]
next_rstate2, logits2 = eval_apply_fn(params2, obs, rstate2)[:2]
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
rstate1, rstate2 = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return rstate1, rstate2, logits.argmax(axis=1)
@jax.jit
def sample_action(
params, next_obs, rstate1, rstate2, main, done, key):
(rstate1, rstate2), logits, value = apply_fn(
params, next_obs, (rstate1, rstate2), done, main)[:3]
value = jnp.squeeze(value, axis=-1)
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, value, key
@jax.jit
def compute_advantage_carry(
next_value, values, rewards, next_dones, mains):
return advantage_fn(
args, next_value, values, rewards, next_dones, mains, return_carry=True)
deck_names = args.deck_names
deck_avg_times = {name: 0 for name in deck_names}
deck_max_times = {name: 0 for name in deck_names}
deck_time_count = {name: 0 for name in deck_names}
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = agent.init_rnn_state(args.local_num_envs)
eval_rstate1 = agent.init_rnn_state(args.local_eval_episodes)
eval_rstate2 = eval_agent.init_rnn_state(args.local_eval_episodes)
next_rstate1, next_rstate2, eval_rstate1, eval_rstate2 = \
jax.device_put([next_rstate1, next_rstate2, eval_rstate1, eval_rstate2], actor_device)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
init_rstates = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
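# stacks a Python list of per-step Transitions into one Transition whose leaves
# gain a leading time axis: T transitions of (num_envs, ...) -> (T, num_envs, ...)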
return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["encoder"]['id_embed']["embedding"].value.block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
for k in range(start_step, args.collect_steps):
if k % args.num_steps == 0:
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
init_rstates.append((init_rstate1, init_rstate2))
global_step += args.local_num_envs * n_actors * args.world_size
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, cached_next_done, cached_main, \
next_rstate1, next_rstate2, action, logits, value, key = sample_action(
params, next_obs, next_rstate1, next_rstate2, main, next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=cached_main,
actions=action,
logits=logits,
values=value,
rewards=next_reward,
next_dones=next_done,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
if args.switch:
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# reached the previous episode boundary (e.g. an OTK where the player never switched)
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
if args.time_log_freq:
for i in range(2):
deck_time = info['step_time'][idx][i]
deck_name = deck_names[info['deck'][idx][i]]
time_count = deck_time_count[deck_name]
avg_time = deck_avg_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
max_time = max(deck_time, deck_max_times[deck_name])
deck_avg_times[deck_name] = avg_time
deck_max_times[deck_name] = max_time
deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % args.time_log_freq == 0:
print(f"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}")
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_steps - args.num_steps
next_main = main_player == next_to_play
if args.collect_steps == args.num_steps:
storage_t = storage
storage = []
next_data = (next_obs, next_main)
else:
storage_t = storage[:args.num_steps]
storage = storage[args.num_steps:]
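# use the extra collected steps (beyond num_steps) to build an advantage carry:
# bootstrap from the value of the newest obs, flipped to the main player's
# perspective (zero-sum), then run the advantage recursion over the remainder so
# the trained segment gets a longer-horizon tail estimate than a one-step bootstrap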
values, rewards, next_dones, mains = prepare_data([
(t.values, t.rewards, t.next_dones, t.mains) for t in storage])
next_value = sample_action(
params, next_obs, next_rstate1, next_rstate2, next_main, next_done, key)[-2]
next_value = jnp.where(next_main, next_value, -next_value)
adv_carry = compute_advantage_carry(
next_value, values, rewards, next_dones, mains)
next_data = adv_carry
partitioned_storage = jax.tree.map(
lambda x: jnp.split(x, len(learner_devices), axis=1), prepare_data(storage_t))
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices) if v is not None else None
for k, v in x.items()
}
elif x is not None:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
init_rstate = init_rstates.pop(0)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate, next_data))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda *x: get_action(params, *x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate1, eval_rstate2)
eval_time = time.time() - _start
other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
else:
eval_stats = None
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
eval_stats,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
max_episode_length = np.max(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
tb_global_step = args.tb_offset + global_step
if device_thread_id == 0:
print(
f"global_step={tb_global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), tb_global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, tb_global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, tb_global_step)
writer.add_scalar("charts/max_episode_length", max_episode_length, tb_global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), tb_global_step)
writer.add_scalar("stats/inference_time", inference_time, tb_global_step)
writer.add_scalar("stats/env_time", env_time, tb_global_step)
writer.add_scalar("charts/SPS", SPS, tb_global_step)
writer.add_scalar("charts/SPS_update", SPS_update, tb_global_step)
def main():
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "local_num_envs / len(learner_device_ids) * num_actor_threads must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
if args.segment_length is not None:
assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"
args.collect_steps = args.collect_steps or args.num_steps
assert args.collect_steps >= args.num_steps, "collect_steps must be greater than or equal to num_steps"
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_devices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
global_main_devices = [
global_devices[process_index * len(local_devices)]
for process_index in range(args.world_size)
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
if args.run_name is None:
timestamp = int(time.time())
run_name = f"{args.exp_name}__{args.seed}__{timestamp}"
else:
run_name = args.run_name
timestamp = int(run_name.split("__")[-1])
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
tb_log_dir = f"{args.tb_dir}/{run_name}"
if args.local_rank == 0 and not args.debug:
writer = SummaryWriter(tb_log_dir)
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
else:
writer = dummy_writer
def save_fn(obj, path):
checkpointer.save(path, obj, force=True)
ckpt_manager = ModelCheckpoint(
args.ckpt_dir, save_fn, n_saved=2)
# seeding
random.seed(args.seed)
seed = random.randint(0, int(1e8))
seed_offset = args.local_rank
seed += seed_offset
init_rngs = nnx.Rngs(seed - seed_offset)
random.seed(seed)
args.real_seed = random.randint(0, int(1e8))
key = jax.random.PRNGKey(args.real_seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck, deck_names = init_ygopro(args.env_id, "english", args.deck, args.code_list_file, return_deck_names=True)
args.deck_names = sorted(deck_names)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, 0, 2, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
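# example (added): with num_minibatches=64 and update_epochs=2, one iteration
# performs 128 gradient updates, so frac = 1 - (count // 128) / num_updates,
# decaying linearly from 1 towards 0 over the run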
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
agent = create_agent(args, init_rngs)
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
agent.encoder.id_embed.embedding.value = jax.device_put(embeddings)
checkpointer = orbax.PyTreeCheckpointer()
if args.checkpoint:
graphdef, state = nnx.split(agent)
state = checkpointer.restore(args.checkpoint, item=state)
agent = nnx.merge(graphdef, state)
print(f"loaded checkpoint from {args.checkpoint}")
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
agent.train()
graphdef, params, batch_stats = nnx.split(agent, nnx.Param, nnx.BatchStat)
train_apply_fn = graphdef.apply
agent.eval()
eval_apply_fn = nnx.graphdef(agent).apply
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
batch_stats=batch_stats,
)
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
eval_agent = create_agent(args, eval=True)
eval_rstate = eval_agent.init_rnn_state(1)
init_key = jax.random.PRNGKey(0)
eval_variables = eval_agent.init(init_key, sample_obs, eval_rstate)
with open(args.eval_checkpoint, "rb") as f:
eval_variables = flax.serialization.from_bytes(eval_variables, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_variables = None
def compute_advantage(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_v):
num_envs = jax.tree.leaves(next_v)[0].shape[0]
num_steps = next_dones.shape[0] // num_envs
def reshape_time_series(x):
return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
ratios = reshape_time_series(ratios)
new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
)
target_values, advantages = advantage_fn(
args, next_v, new_values_, rewards, next_dones, switch_or_mains, ratios)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
return target_values, advantages
def compute_loss(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=None):
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
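# ratios = pi_new(a|s) / pi_old(a|s) for the taken actions; the quantity
# (ratios - 1) - log(ratios) below is the non-negative "k3" estimator of
# KL(pi_old || pi_new)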
logratio = jnp.log(ratios)
approx_kl = (ratios - 1) - logratio
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
pg_loss = simple_policy_loss(
ratios, logits, new_logits, advantages, args.spo_kld_max)
elif args.logits_threshold is not None:
pg_loss = ach_loss(
actions, logits, new_logits, advantages, args.logits_threshold, args.clip_coef, args.dual_clip_coef)
elif args.ppo_clip:
pg_loss = clipped_surrogate_pg_loss(
ratios, advantages, args.clip_coef, args.dual_clip_coef)
else:
pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)
v_loss = mse_loss(new_values, target_values)
if args.vloss_clip is not None:
v_loss = jnp.minimum(v_loss, args.vloss_clip)
ent_loss = entropy_loss(new_logits)
if args.burn_in_steps:
mask = jax.tree.map(
lambda x: x.reshape(num_steps, -1), mask)
burn_in_mask = jnp.arange(num_steps) < args.burn_in_steps
mask = jnp.where(burn_in_mask[:, None], 0.0, mask)
mask = jnp.reshape(mask, (-1,))
n_valids = jnp.sum(mask)
pg_loss, v_loss, ent_loss, approx_kl = jax.tree.map(
lambda x: jnp.sum(x * mask) / n_valids, (pg_loss, v_loss, ent_loss, approx_kl))
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, pg_loss, v_loss, ent_loss, approx_kl
def apply_fn(
state, obs, init_rstate, dones, next_dones, switch_or_mains, train=True):
if args.switch:
dones = dones | next_dones
_apply_fn = train_apply_fn if train else eval_apply_fn
((rstate1, rstate2), new_logits, new_values, _), (_, new_state) = \
_apply_fn(state)(obs, init_rstate, dones, switch_or_mains)
batch_stats = new_state.split(nnx.Param, nnx.BatchStat)[1]
new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
return ((rstate1, rstate2), new_logits, new_values), batch_stats
def compute_next_value(state, next_rstate, next_obs, next_main):
rstate1, rstate2 = next_rstate
rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), rstate1, rstate2)
next_value = eval_apply_fn(state)(next_obs, rstate)[0][2]
next_value = jax.tree.map(lambda x: x.squeeze(-1), next_value)
next_value = jax.lax.stop_gradient(next_value)
sign = -1 if args.switch else 1
next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
return next_value
def get_advantage(
state, init_rstate, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_obs, next_main):
num_steps = dones.shape[0]
obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
jax.tree.map(
lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
(obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
next_rstate, new_logits, new_values = apply_fn(
state, obs, init_rstate, dones, next_dones, switch_or_mains, train=False)[0]
next_value = compute_next_value(
state, next_rstate, next_obs, next_main)
target_values, advantages = compute_advantage(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, -1) + x.shape[2:]),
(target_values, advantages))
return target_values, advantages
def get_loss(
params, batch_stats, init_rstate, obs, dones, next_dones,
switch_or_mains, actions, logits, target_values, advantages, mask):
state = nnx.State.merge(params, batch_stats)
((rstate1, rstate2), new_logits, new_values), batch_stats = apply_fn(
state, obs, init_rstate, dones, next_dones, switch_or_mains)
loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=None)
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
approx_kl, rstate1, rstate2 = jax.tree.map(
jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
return loss, (batch_stats, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
def get_advantage_loss(
params, batch_stats, init_rstate, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, mask, next_data):
num_envs = jax.tree.leaves(next_data)[0].shape[0]
state = nnx.State.merge(params, batch_stats)
(next_rstate, new_logits, new_values), batch_stats = apply_fn(
state, obs, init_rstate, dones, next_dones, switch_or_mains)
if args.collect_steps == args.num_steps:
next_obs, next_main = next_data
state = nnx.State.merge(params, batch_stats)
next_v = compute_next_value(
state, next_rstate, next_obs, next_main)
else:
next_v = next_data
target_values, advantages = compute_advantage(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_v)
loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=dones.shape[0] // num_envs)
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
approx_kl = jax.lax.stop_gradient(approx_kl)
return loss, (batch_stats, pg_loss, v_loss, ent_loss, approx_kl)
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate: List,
sharded_next_data: List,
key: jax.random.PRNGKey,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_data, init_rstate = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_data, sharded_init_rstate]
]
# reorder storage of individual players
# main first, opponent second
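# worked example (added): for one env with num_steps=4 and mains=[1, 0, 1, 0],
# T - mains * num_steps = [-4, 1, -2, 3], so argsort yields indices [0, 2, 1, 3]:
# the main player's steps come first (order preserved), then the opponent's.
# switch_steps=2, so `switch` marks reordered index 1, the main player's last
# step before the data switches to the opponent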
num_steps, num_envs = storage.rewards.shape
if args.switch:
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
if args.segment_length is None:
loss_grad_fn = jax.value_and_grad(get_advantage_loss, has_aux=True)
else:
# TODO: fix it
loss_grad_fn = jax.value_and_grad(get_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def convert_data(x: jnp.ndarray, multi_step=True):
return reshape_minibatch(
x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=subkey)
b_init_rstate, b_next_data = \
jax.tree.map(partial(convert_data, multi_step=False),
(init_rstate, next_data))
b_storage = jax.tree.map(convert_data, storage)
if args.switch:
switch_or_mains = convert_data(switch)
else:
switch_or_mains = b_storage.mains
b_mask = ~b_storage.dones
b_rewards = b_storage.rewards
if args.segment_length is None:
def update_minibatch(agent_state, minibatch):
(loss, (batch_stats, pg_loss, v_loss, ent_loss, approx_kl)), grads = \
loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=batch_stats)
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
else:
def update_minibatch(carry, minibatch):
def update_minibatch_t(carry, minibatch_t):
agent_state, init_rstate = carry
minibatch_t = init_rstate, *minibatch_t
(loss, (batch_stats, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
grads = loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch_t)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=batch_stats)
return (agent_state, (rstate1, rstate2)), (loss, pg_loss, v_loss, ent_loss, approx_kl)
init_rstate, *minibatch_t, mask = minibatch
target_values, advantages = get_advantage(
get_state(carry), init_rstate, *minibatch_t)
minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
(carry, _next_rstate), \
(loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch_t, (carry, init_rstate), minibatch_t)
return carry, (loss, pg_loss, v_loss, ent_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
b_init_rstate,
b_storage.obs,
b_storage.dones,
b_storage.next_dones,
switch_or_mains,
b_storage.actions,
b_storage.logits,
b_rewards,
b_mask,
b_next_data,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
ent_loss = jax.lax.pmean(ent_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, key
all_reduce_value = jax.pmap(
lambda x: jax.lax.pmean(x, axis_name="main_devices"),
axis_name="main_devices",
devices=global_main_devices,
)
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_devices,
)
params_queues = []
rollout_queues = []
unreplicated_params = flax.jax_utils.unreplicate(get_state(agent_state))
for d_idx, d_id in enumerate(args.actor_device_ids):
actor_device = local_devices[d_id]
device_params = jax.device_put(unreplicated_params, actor_device)
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
if eval_variables:
params_queues[-1].put(
jax.device_put(eval_variables, actor_device))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread(
target=rollout,
args=(
jax.device_put(actor_keys[actor_thread_id], actor_device),
args,
rollout_queues[-1],
params_queues[-1],
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
actor_device,
learner_devices,
actor_thread_id,
),
).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
eval_stat_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
eval_stats,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
if eval_stats is not None:
eval_stat_list.append(eval_stats)
tb_global_step = args.tb_offset + global_step
if args.eval_interval and update % args.eval_interval == 0:
eval_stats = np.mean(eval_stat_list, axis=0)
eval_stats = jax.device_put(eval_stats, local_devices[0])
eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
eval_time, eval_return, eval_win_rate = eval_stats
writer.add_scalar(f"charts/eval_return", eval_return, tb_global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, tb_global_step)
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(get_state(agent_state))
params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["encoder"]['id_embed']["embedding"].value.block_until_ready()
params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
params_queue_put_time += time.time() - params_queue_put_start
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), tb_global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
tb_global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, tb_global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), tb_global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), tb_global_step)
print(
f"{tb_global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}, "
f"put_time={params_queue_put_time:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), tb_global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/entropy", ent_loss[-1].item(), tb_global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), tb_global_step)
writer.add_scalar("losses/loss", loss, tb_global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = tb_global_step // 2**20
ckpt_name = f"{timestamp}_{M_steps}M"
ckpt_manager.save(unreplicated_params, ckpt_name)
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
if __name__ == "__main__":
main()
+from typing import List
 import os
+import shutil
 from pathlib import Path
 import zipfile
@@ -16,10 +19,10 @@ class ModelCheckpoint(object):
     """

     def __init__(self, dirname, save_fn, n_saved=1):
-        self._dirname = Path(dirname).expanduser()
+        self._dirname = Path(dirname).expanduser().absolute()
         self._n_saved = n_saved
         self._save_fn = save_fn
-        self._saved = []
+        self._saved: List[Path] = []

     def _check_dir(self):
         self._dirname.mkdir(parents=True, exist_ok=True)
@@ -38,6 +41,9 @@ class ModelCheckpoint(object):
         if len(self._saved) > self._n_saved:
             path = self._saved.pop(0)
-            os.remove(path)
+            if path.is_dir():
+                shutil.rmtree(path)
+            else:
+                os.remove(path)

     def get_latest(self):
...
@@ -452,13 +452,14 @@ class Actor(nn.Module):
     channels: int = 128
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
+    final_init: nn.initializers.Initializer = nn.initializers.orthogonal(0.01)

     @nn.compact
     def __call__(self, f_state, f_actions, mask):
         f_state = f_state.astype(self.dtype)
         f_actions = f_actions.astype(self.dtype)
         c = self.channels
-        mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
+        mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=self.final_init)
         f_state = mlp((c,), use_bias=True)(f_state)
         logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
         big_neg = jnp.finfo(logits.dtype).min
@@ -471,6 +472,7 @@ class FiLMActor(nn.Module):
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
     noam: bool = False
+    final_init: nn.initializers.Initializer = nn.initializers.orthogonal(0.01)

     @nn.compact
     def __call__(self, f_state, f_actions, mask):
@@ -486,7 +488,7 @@ class FiLMActor(nn.Module):
             f_actions, mask, a_s, a_b, o_s, o_b)
         logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
-                          kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
+                          kernel_init=self.final_init)(f_actions)[:, :, 0]
         big_neg = jnp.finfo(logits.dtype).min
         logits = jnp.where(mask, big_neg, logits)
         return logits
@@ -647,6 +649,7 @@ class RNNAgent(nn.Module):
     critic_depth: int = 3
     version: int = 0
+    q_head: bool = False
     switch: bool = True
     freeze_id: bool = False
     int_head: bool = False
@@ -699,11 +702,6 @@ class RNNAgent(nn.Module):
         num_steps = f_state.shape[0] // batch_size
         multi_step = num_steps > 1

-        if done is not None:
-            assert switch_or_main is not None
-        else:
-            assert not multi_step
-
         if multi_step:
             f_state_r, done, switch_or_main = jax.tree.map(
                 lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
@@ -722,13 +720,16 @@ class RNNAgent(nn.Module):
             # f_state_r = ReZero(channel_wise=True)(f_state_r)
             f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
+        actor_init = nn.initializers.orthogonal(1) if self.q_head else nn.initializers.orthogonal(0.01)
         if self.film:
             actor = FiLMActor(
-                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam, final_init=actor_init)
         else:
             actor = Actor(
-                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, final_init=actor_init)
         logits = actor(f_state_r, f_actions, mask)
+        if self.q_head:
+            return rstate, logits, valid
         CriticCls = CrossCritic if self.batch_norm else Critic
         cs = [self.critic_width] * self.critic_depth
...
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Sequence, Literal
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx
from ygoai.rl.jax.nnx.transformer import EncoderLayer, PositionalEncoding
from ygoai.rl.jax.nnx.modules import MLP, GLUMlp, BatchRenorm, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.nnx.rnn import GRUCell, OptimizedLSTMCell
default_embed_init = nnx.initializers.uniform(scale=0.001)
default_fc_init1 = nnx.initializers.uniform(scale=0.001)
default_fc_init2 = nnx.initializers.uniform(scale=0.001)
class ActionEncoder(nnx.Module):
def __init__(self, channels, *, dtype=None, param_dtype=jnp.float32, rngs: nnx.Rngs):
self.channels = channels
self.dtype = dtype
self.param_dtype = param_dtype
c = self.channels
div = 8
embed = partial(
nnx.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init, rngs=rngs)
self.e_msg = embed(30, c // div)
self.e_act = embed(10, c // div)
self.e_finish = embed(3, c // div // 2)
self.e_effect = embed(256, c // div * 2)
self.e_phase = embed(4, c // div // 2)
self.e_position = embed(9, c // div)
self.e_number = embed(13, c // div // 2)
self.e_place = embed(31, c // div)
self.e_attrib = embed(10, c // div // 2)
def __call__(self, x):
x_a_msg = self.e_msg(x[:, :, 0])
x_a_act = self.e_act(x[:, :, 1])
x_a_finish = self.e_finish(x[:, :, 2])
x_a_effect = self.e_effect(x[:, :, 3])
x_a_phase = self.e_phase(x[:, :, 4])
x_a_position = self.e_position(x[:, :, 5])
x_a_number = self.e_number(x[:, :, 6])
x_a_place = self.e_place(x[:, :, 7])
x_a_attrib = self.e_attrib(x[:, :, 8])
return [
x_a_msg, x_a_act, x_a_finish, x_a_effect, x_a_phase,
x_a_position, x_a_number, x_a_place, x_a_attrib]
class CardEncoder(nnx.Module):
def __init__(
self, channels, id_embed_dim, *, version=1,
dtype=None, param_dtype=jnp.float32, rngs: nnx.Rngs):
self.channels = channels
self.version = version
self.dtype = dtype
self.param_dtype = param_dtype
self.n_bins = 32
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype, rngs=rngs)
norm = partial(
nnx.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype, rngs=rngs)
embed = partial(
nnx.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init, rngs=rngs)
fc_embed = partial(
nnx.Linear, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, rngs=rngs)
self.fc_num = mlp(self.n_bins, c // 8, last_lin=False)
self.e_loc = embed(9, c // 16 * 2)
self.e_seq = embed(76, c // 16 * 2)
self.e_owner = embed(2, c // 16)
self.e_position = embed(9, c // 16)
self.e_overlay = embed(2, c // 16)
self.e_attribute = embed(8, c // 16)
self.e_race = embed(27, c // 16)
self.e_level = embed(14, c // 16)
self.e_counter = embed(16, c // 16)
self.e_negated = embed(3, c // 16)
self.fc_atk = fc_embed(c // 8, c // 16, kernel_init=default_fc_init1)
self.fc_def = fc_embed(c // 8, c // 16, kernel_init=default_fc_init1)
self.fc_type = fc_embed(25, c // 16 * 2, kernel_init=default_fc_init2)
self.fc_id = mlp(id_embed_dim, c, kernel_init=default_fc_init2)
self.fc_cards = mlp(c, c, kernel_init=default_fc_init2)
self.norm = norm(c)
def num_transform(self, x):
bin_points, bin_intervals = make_bin_params(n_bins=32)
return self.fc_num(bytes_to_bin(x, bin_points, bin_intervals))
def __call__(self, x_id, x, mask):
x1 = x[:, :, :10].astype(jnp.int32)
x2 = x[:, :, 10:].astype(self.dtype)
c_mask = x1[:, :, 0] == 0  # mask empty card slots (location == 0)
c_mask = c_mask.at[:, 0].set(False)
x_loc = self.e_loc(x1[:, :, 0])
x_seq = self.e_seq(x1[:, :, 1])
x_owner = self.e_owner(x1[:, :, 2])
x_position = self.e_position(x1[:, :, 3])
x_overlay = self.e_overlay(x1[:, :, 4])
x_attribute = self.e_attribute(x1[:, :, 5])
x_race = self.e_race(x1[:, :, 6])
x_level = self.e_level(x1[:, :, 7])
x_counter = self.e_counter(x1[:, :, 8])
x_negated = self.e_negated(x1[:, :, 9])
x_atk = self.num_transform(x2[:, :, 0:2])
x_atk = self.fc_atk(x_atk)
x_def = self.num_transform(x2[:, :, 2:4])
x_def = self.fc_def(x_def)
x_type = self.fc_type(x2[:, :, 4:])
x_id = nnx.swish(self.fc_id(x_id))
feats_g = [
x_id, x_loc, x_seq, x_owner, x_position, x_overlay, x_attribute,
x_race, x_level, x_counter, x_negated, x_atk, x_def, x_type]
if mask is not None:
assert len(feats_g) == mask.shape[-1]
feats = [
jnp.where(mask[..., i:i+1] == 1, f, f[..., -1:, :])
for i, f in enumerate(feats_g)
]
else:
feats = feats_g
x_cards = jnp.concatenate(feats[1:], axis=-1)
x_cards = self.fc_cards(x_cards)
x_cards = x_cards * feats[0]
f_cards = self.norm(x_cards)
return f_cards, c_mask
class GlobalEncoder(nnx.Module):
def __init__(
self, channels, *, version=1, dtype=None, param_dtype=jnp.float32, rngs: nnx.Rngs):
self.channels = channels
self.version = version
self.dtype = dtype
self.param_dtype = param_dtype
self.n_bins = 32
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype, rngs=rngs)
norm = partial(
nnx.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype, rngs=rngs)
embed = partial(
nnx.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init, rngs=rngs)
fc_embed = partial(
nnx.Linear, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, rngs=rngs)
self.fc_num = mlp(self.n_bins, c // 8, last_lin=False)
self.fc_lp = fc_embed(c // 8, c // 4, kernel_init=default_fc_init2)
self.fc_oppo_lp = fc_embed(c // 8, c // 4, kernel_init=default_fc_init2)
self.e_turn = embed(20, c // 8)
self.e_phase = embed(11, c // 8)
self.e_if_first = embed(2, c // 8)
self.e_is_my_turn = embed(2, c // 8)
self.e_count = embed(100, c // 16)
self.e_hand_count = embed(100, c // 16)
self.norm = norm(c * 2)
self.out_channels = c * 2
def num_transform(self, x):
bin_points, bin_intervals = make_bin_params(n_bins=32)
return self.fc_num(bytes_to_bin(x, bin_points, bin_intervals))
def __call__(self, x):
x1 = x[:, :4].astype(self.dtype)
x2 = x[:, 4:8].astype(jnp.int32)
x3 = x[:, 8:22].astype(jnp.int32)
x_lp = self.fc_lp(self.num_transform(x1[:, 0:2]))
x_oppo_lp = self.fc_oppo_lp(self.num_transform(x1[:, 2:4]))
x_turn = self.e_turn(x2[:, 0])
x_phase = self.e_phase(x2[:, 1])
x_if_first = self.e_if_first(x2[:, 2])
x_is_my_turn = self.e_is_my_turn(x2[:, 3])
x_cs = self.e_count(x3).reshape((x.shape[0], -1))
x_my_hand_c = self.e_hand_count(x3[:, 1])
x_op_hand_c = self.e_hand_count(x3[:, 8])
x = jnp.concatenate([
x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
x = self.norm(x)
return x
def create_id_embed(embedding_shape, dtype, param_dtype, rngs):
if embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(embedding_shape, int):
n_embed, embed_dim = embedding_shape, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
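# example (added): with embeddings loaded via load_embeddings, pass
# embedding_shape=embeddings.shape; main() prepends a mean "unknown" row to the
# table so that this reserved index 0 stays consistent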
return nnx.Embed(
n_embed, embed_dim, dtype=dtype, param_dtype=param_dtype,
embedding_init=default_embed_init, rngs=rngs)
class Encoder(nnx.Module):
def __init__(
self, channels, out_channels=None, num_layers=2, embedding_shape=None,
*, freeze_id=False, use_history=True, card_mask=False, noam=False,
action_feats=True, version=1, dtype=None, param_dtype=jnp.float32,
rngs: nnx.Rngs):
self.channels = channels
self.out_channels = out_channels
self.num_layers = num_layers
self.freeze_id = freeze_id
self.use_history = use_history
self.card_mask = card_mask
self.noam = noam
self.action_feats = action_feats
self.version = version
key = rngs.params()
c = self.channels
norm = partial(
nnx.LayerNorm, use_scale=True, use_bias=True, dtype=dtype, rngs=rngs)
embed = partial(
nnx.Embed, dtype=dtype, param_dtype=param_dtype,
embedding_init=default_embed_init, rngs=rngs)
fc_layer = partial(
nnx.Linear, use_bias=False, param_dtype=param_dtype, dtype=dtype, rngs=rngs)
self.id_embed = create_id_embed(embedding_shape, dtype, param_dtype, rngs)
embed_dim = self.id_embed.features
self.action_encoder = ActionEncoder(
channels=channels, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
# Cards
self.card_encoder = CardEncoder(
channels=channels, id_embed_dim=embed_dim,
version=version, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
n_heads = max(2, c // 128)
self.g_card_embed = nnx.Param(
jax.random.normal(key, (1, 1, c), param_dtype) * 0.02)
for i in range(num_layers):
layer = EncoderLayer(
c, n_heads, llama=self.noam, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
setattr(self, f'card_layer{i+1}', layer)
self.card_norm = norm(c)
# Global
self.global_encoder = GlobalEncoder(
c, version=version, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
in_channels = self.global_encoder.out_channels
if self.version == 2:
self.fc_global = fc_layer(in_channels, c, rngs=rngs)
self.prenorm_global = norm(c)
self.mlp_global = GLUMlp(
c, c * 2, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
else:
self.mlp_global = MLP(
in_channels, (c * 2, c * 2), dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.fc_global = fc_layer(c * 2, c, rngs=rngs)
self.global_norm = norm(c)
# History actions
self.fc_h_id = fc_layer(embed_dim, c, rngs=rngs)
self.e_h_turn = embed(20, c // 2)
self.e_h_phase = embed(12, c // 2)
self.ha_norm_cat = norm(c * 3)
self.ha_fc = fc_layer(c * 3, c, rngs=rngs)
if self.noam:
self.ha_layer = EncoderLayer(
c, n_heads, llama=True, rope=True, rope_max_len=64,
dtype=dtype, param_dtype=param_dtype, rngs=rngs)
else:
self.ha_pe = PositionalEncoding()
self.ha_layer = EncoderLayer(
c, n_heads, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.ha_norm = norm(c)
# Actions
self.na_card_embed = nnx.Param(
jax.random.normal(key, (1, 1, c), param_dtype) * 0.02)
self.fc_a_id = fc_layer(embed_dim, c, rngs=rngs)
self.norm_a_cat = norm(c * 2)
self.fc_a_cat = fc_layer(c * 2, c, rngs=rngs)
self.fc_a_cards = fc_layer(c, c, rngs=rngs)
self.fc_a = fc_layer(c, c, rngs=rngs)
# State
self.fc_a_g = fc_layer(c, c, rngs=rngs)
oc = self.out_channels or c
if self.version == 2:
self.mlp_state = GLUMlp(
c * 4, c * 2, oc, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
else:
self.mlp_state = MLP(
    c * 4, (c * 2, oc), dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.state_norm = norm(oc)
def encode_id(self, x):
x = decode_id(x)
x = self.id_embed(x)
if self.freeze_id:
x = jax.lax.stop_gradient(x)
return x
def concat_token(self, x, token, mask=None):
batch_size = x.shape[0]
token = jnp.tile(token, (batch_size, 1, 1)).astype(x.dtype)
x = jnp.concatenate([token, x], axis=1)
if mask is not None:
mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=mask.dtype), mask], axis=1)
return x, mask
def __call__(self, x):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
mask = x['mask_']
batch_size = x_global.shape[0]
valid = x_global[:, -1] == 0
# Cards
x_id = self.encode_id(x_cards[:, :, :2].astype(jnp.int32))
f_cards, c_mask = self.card_encoder(x_id, x_cards[:, :, 2:], mask)
f_cards, c_mask = self.concat_token(f_cards, self.g_card_embed.value, c_mask if self.card_mask else None)
for i in range(self.num_layers):
f_cards = getattr(self, f'card_layer{i+1}')(
f_cards, src_key_padding_mask=c_mask)
f_cards = self.card_norm(f_cards)
f_g_card = f_cards[:, 0]
# Global
x_global = self.global_encoder(x_global)
if self.version == 2:
x_global = self.fc_global(x_global)
f_global = x_global + self.mlp_global(self.prenorm_global(x_global))
else:
f_global = x_global + self.mlp_global(x_global)
f_global = self.fc_global(f_global)
f_global = self.global_norm(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
h_mask = x_h_actions[:, :, 3] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = self.encode_id(x_h_actions[..., 1:3])
x_h_id = self.fc_h_id(x_h_id)
x_h_a_feats = self.action_encoder(x_h_actions[:, :, 3:12])
x_h_a_turn = self.e_h_turn(x_h_actions[:, :, 12])
x_h_a_phase = self.e_h_phase(x_h_actions[:, :, 13])
x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
x_h_a_feats = self.ha_norm_cat(x_h_a_feats)
x_h_a_feats = self.ha_fc(x_h_a_feats)
if not self.noam:
x_h_a_feats = self.ha_pe(x_h_a_feats)
f_h_actions = self.ha_layer(x_h_a_feats, src_key_padding_mask=h_mask)
f_g_h_actions = self.ha_norm(f_h_actions[:, 0])
# Actions
x_actions = x_actions.astype(jnp.int32)
f_cards = self.concat_token(f_cards[:, 1:], self.na_card_embed.value)[0]
spec_index = x_actions[..., 0]
f_a_cards = f_cards[jnp.arange(batch_size)[:, None], spec_index]
x_a_id = self.encode_id(x_actions[..., 1:3])
x_a_id = self.fc_a_id(x_a_id)
x_a_feats = self.action_encoder(x_actions[..., 3:])
x_a_feats.append(x_a_id)
x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
x_a_feats = self.norm_a_cat(x_a_feats)
x_a_feats = self.fc_a_cat(x_a_feats)
f_a_cards = self.fc_a_cards(f_a_cards)
f_actions = nnx.silu(f_a_cards) * x_a_feats
f_actions = x_a_feats + self.fc_a(f_actions)
a_mask = x_actions[:, :, 3] == 0
a_mask = a_mask.at[:, 0].set(False)
# State
g_feats = [f_g_card, f_global]
if self.use_history:
g_feats.append(f_g_h_actions)
if self.action_feats:
f_actions_g = self.fc_a_g(f_actions)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
g_feats.append(f_g_actions)
f_state = jnp.concatenate(g_feats, axis=-1)
f_state = self.mlp_state(f_state)
f_state = self.state_norm(f_state)
return f_actions, f_state, a_mask, valid
class Actor(nnx.Module):
def __init__(
self, in_channels, channels, *, dtype=None, param_dtype=jnp.float32,
final_init=nnx.initializers.orthogonal(0.01), rngs: nnx.Rngs):
self.channels = channels
self.dtype = dtype
self.param_dtype = param_dtype
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype,
last_kernel_init=final_init, rngs=rngs)
self.mlp = mlp(in_channels, channels, use_bias=True)
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
f_state = self.mlp(f_state)
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
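# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): shows the shape contract of Actor. The channel dim of f_actions
# must match `channels`, and `mask` marks invalid actions with True.
def _example_actor_usage():
    actor = Actor(256, 128, dtype=jnp.float32, rngs=nnx.Rngs(0))
    f_state = jnp.ones((4, 256))               # (batch, in_channels)
    f_actions = jnp.ones((4, 24, 128))         # (batch, max_actions, channels)
    mask = jnp.zeros((4, 24), dtype=bool)      # True = masked-out action
    return actor(f_state, f_actions, mask)     # logits: (batch, max_actions)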
class FiLMActor(nnx.Module):
def __init__(
self, in_channels, channels, *, noam=False, dtype=None, param_dtype=jnp.float32,
final_init=nnx.initializers.orthogonal(0.01), rngs: nnx.Rngs):
self.channels = channels
self.dtype = dtype
self.param_dtype = param_dtype
c = self.channels
self.fc = nnx.Linear(
in_channels, channels * 4, dtype=self.dtype, param_dtype=self.param_dtype, rngs=rngs)
n_heads = max(2, channels // 128)
self.encoder = EncoderLayer(
channels, n_heads, llama=noam, dtype=self.dtype,
param_dtype=self.param_dtype, rngs=rngs)
self.out = nnx.Linear(
channels, 1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=final_init, rngs=rngs)
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
t = self.fc(f_state)
a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
f_actions = self.encoder(
f_actions, a_s, a_b, o_s, o_b, src_key_padding_mask=mask)
logits = self.out(f_actions)[:, :, 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
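# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): FiLMActor conditions the per-action encoder on f_state via four
# FiLM terms (attention scale/bias and output scale/bias) split from a single
# projection of the state.
def _example_film_actor_usage():
    actor = FiLMActor(256, 128, dtype=jnp.float32, rngs=nnx.Rngs(0))
    f_state = jnp.ones((4, 256))
    f_actions = jnp.ones((4, 24, 128))
    mask = jnp.zeros((4, 24), dtype=bool)
    return actor(f_state, f_actions, mask)     # logits: (4, 24)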
class Critic(nnx.Module):
def __init__(
self, in_channels, channels=(128, 128, 128), *,
dtype=None, param_dtype=jnp.float32, rngs: nnx.Rngs):
self.channels = channels
self.dtype = dtype
self.param_dtype = param_dtype
self.mlp = MLP(
in_channels, channels, last_lin=False,
dtype=self.dtype, param_dtype=self.param_dtype, rngs=rngs)
final_init = nnx.initializers.orthogonal(1.0)
self.out = nnx.Linear(
channels[-1], 1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=final_init, rngs=rngs)
def __call__(self, f_state):
f_state = f_state.astype(self.dtype)
x = self.mlp(f_state)
x = self.out(x)
return x
class CrossCritic(nnx.Module):
def __init__(
self, in_channels, channels=(128, 128, 128), bn_momentum=0.99,
*, dtype=None, param_dtype=jnp.float32, rngs: nnx.Rngs):
self.channels = channels
self.dtype = dtype
self.param_dtype = param_dtype
linear = partial(
nnx.Linear, dtype=self.dtype, param_dtype=self.param_dtype,
use_bias=False, rngs=rngs)
BN = partial(
BatchRenorm, dtype=self.dtype, param_dtype=self.param_dtype,
momentum=bn_momentum, axis_name="local_devices")
ic = in_channels
self.bn = BN(ic)
for i, c in enumerate(self.channels):
setattr(self, f'fc{i + 1}', linear(ic, c))
setattr(self, f'bn{i + 1}', BN(c))
ic = c
self.out = nnx.Linear(
ic, 1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=nnx.initializers.orthogonal(1.0), rngs=rngs)
def __call__(self, f_state):
x = f_state.astype(self.dtype)
x = self.bn(x)
for i in range(len(self.channels)):
x = getattr(self, f'fc{i + 1}')(x)
x = nnx.relu(x)
x = getattr(self, f'bn{i + 1}')(x)
x = self.out(x)
return x
def rnn_step_by_main(rnn_layer, rstate, f_state, done, main, return_state=False):
    if main is not None:
        # rstate holds one recurrent state per player; select the acting
        # player's state for this step and write it back below.
        rstate1, rstate2 = rstate
        rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, f_state = rnn_layer(rstate, f_state)
if main is not None:
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate = rstate1, rstate2
if done is not None:
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
if return_state:
return rstate, (f_state, rstate)
else:
return rstate, f_state
def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True, return_state=False):
if switch:
    # In switch mode the carry is a pair (rstate, init_rstate2): the running
    # state plus the initial state it is reset to whenever the acting player
    # switches; `done` zeroes the state at episode boundaries first.
    def scan_fn(carry, cell, x, done, switch):
        rstate, init_rstate2 = carry
        rstate, y = cell(rstate, x)
        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
        rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
        return (rstate, init_rstate2), y
else:
def scan_fn(carry, cell, x, done, main):
return rnn_step_by_main(cell, carry, x, done, main, return_state)
rstate, f_state = nnx.scan(
scan_fn, state_axes={}
)(rstate, rnn_layer, f_state, done, switch_or_main)
return rstate, f_state
class Memory(nnx.Module):
def __init__(
self, in_channels, channels, rnn_type, switch=False,
*, dtype=None, param_dtype=jnp.float32, rngs: nnx.Rngs):
self.in_channels = in_channels
self.channels = channels
self.rnn_type = rnn_type
self.switch = switch
self.dtype = dtype
self.param_dtype = param_dtype
if rnn_type == 'lstm':
self.rnn = OptimizedLSTMCell(
in_channels, channels, dtype=dtype, param_dtype=param_dtype,
kernel_init=nnx.initializers.orthogonal(1.0), rngs=rngs)
elif rnn_type == 'gru':
self.rnn = GRUCell(
in_channels, channels, dtype=dtype, param_dtype=param_dtype,
kernel_init=nnx.initializers.orthogonal(1.0), rngs=rngs)
elif rnn_type == 'rwkv':
raise NotImplementedError
# num_heads = channels // 32
# self.rnn = Rwkv6SelfAttention(
# num_heads, dtype=dtype, param_dtype=param_dtype)
else:
self.rnn = None
def __call__(self, rstate, x, done=None, switch_or_main=None):
if self.rnn is None:
return rstate, x
batch_size = jax.tree.leaves(rstate)[0].shape[0]
num_steps = x.shape[0] // batch_size
multi_step = num_steps > 1
if multi_step:
x, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (x, done, switch_or_main))
rstate, x = rnn_forward_2p(
self.rnn, rstate, x, done, switch_or_main, self.switch, return_state=False)
x = x.reshape((-1, x.shape[-1]))
else:
rstate, x = rnn_step_by_main(
self.rnn, rstate, x, done, switch_or_main, return_state=False)
return rstate, x
def init_state(self, batch_size):
if self.rnn_type == 'lstm':
return (
np.zeros((batch_size, self.channels)),
np.zeros((batch_size, self.channels)),
)
elif self.rnn_type == 'gru':
return np.zeros((batch_size, self.channels))
elif self.rnn_type == 'rwkv':
raise NotImplementedError
# head_size = self.rwkv_head_size
# num_heads = self.channels // self.rwkv_head_size
# return (
# np.zeros((batch_size, num_heads*head_size)),
# np.zeros((batch_size, num_heads*head_size*head_size)),
# )
else:
return None
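# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): a single-step LSTM update. With x of shape (batch, in_channels),
# num_steps is 1 and the call reduces to one rnn_step_by_main without the
# done/switch bookkeeping.
def _example_memory_step():
    memory = Memory(128, 512, 'lstm', rngs=nnx.Rngs(0))
    rstate = memory.init_state(batch_size=8)   # (c, h) zero states
    x = jnp.ones((8, 128))
    rstate, y = memory(rstate, x)              # y: (8, 512)
    return rstate, y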
@dataclass
class EncoderArgs:
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
use_history: bool = True
"""whether to use history actions as input for agent"""
card_mask: bool = False
"""whether to mask the padding card as ignored in the transformer"""
noam: bool = False
"""whether to use Noam architecture for the transformer layer"""
action_feats: bool = True
"""whether to use action features for the global state"""
version: int = 0
"""the version of the environment and the agent"""
@dataclass
class ModelArgs(EncoderArgs):
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
rnn_type: Optional[Literal['lstm', 'gru', 'rwkv', 'none']] = "lstm"
"""the type of RNN to use, None for no RNN"""
film: bool = False
"""whether to use FiLM for the actor"""
rnn_shortcut: bool = False
"""whether to use shortcut for the RNN"""
batch_norm: bool = False
"""whether to use batch normalization for the critic"""
critic_width: int = 128
"""the width of the critic"""
critic_depth: int = 3
"""the depth of the critic"""
rwkv_head_size: int = 32
"""the head size for the RWKV"""
class RNNAgent(nnx.Module):
def __init__(
self,
num_layers: int = 2,
num_channels: int = 128,
rnn_channels: int = 512,
use_history: bool = True,
card_mask: bool = False,
rnn_type: str = 'lstm',
film: bool = False,
noam: bool = False,
rwkv_head_size: int = 32,
action_feats: bool = True,
rnn_shortcut: bool = False,
batch_norm: bool = False,
critic_width: int = 128,
critic_depth: int = 3,
version: int = 0,
q_head: bool = False,
switch: bool = True,
freeze_id: bool = False,
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
rngs: nnx.Rngs = None
):
self.rnn_shortcut = rnn_shortcut
self.q_head = q_head
c = num_channels
oc = rnn_channels if rnn_type == 'rwkv' else c
self.encoder = Encoder(
num_channels,
out_channels=oc,
num_layers=num_layers,
embedding_shape=embedding_shape,
freeze_id=freeze_id,
use_history=use_history,
card_mask=card_mask,
noam=noam,
action_feats=action_feats,
version=version,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
self.memory = Memory(
oc, rnn_channels, rnn_type, switch=switch,
dtype=dtype, param_dtype=param_dtype, rngs=rngs)
ic = rnn_channels + oc if rnn_shortcut else rnn_channels
actor_init = nnx.initializers.orthogonal(1) if self.q_head else nnx.initializers.orthogonal(0.01)
actor_cls = partial(FiLMActor, noam=noam) if film else Actor
self.actor = actor_cls(
ic, c, dtype=jnp.float32, param_dtype=param_dtype, final_init=actor_init, rngs=rngs)
critic_cls = CrossCritic if batch_norm else Critic
cs = [critic_width] * critic_depth
self.critic = critic_cls(
ic, channels=cs, dtype=jnp.float32, param_dtype=param_dtype, rngs=rngs)
def __call__(self, x, rstate, done=None, switch_or_main=None):
f_actions, f_state, mask, valid = self.encoder(x)
rstate, f_state_r = self.memory(rstate, f_state, done, switch_or_main)
if self.rnn_shortcut:
f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
logits = self.actor(f_state_r, f_actions, mask)
if self.q_head:
return rstate, logits, valid
value = self.critic(f_state_r)
return rstate, logits, value, valid
def init_rnn_state(self, batch_size):
return self.memory.init_state(batch_size)
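# Hedged construction sketch (not part of the original commit): builds the agent
# with default hyperparameters; embedding_shape=999 is an illustrative assumption
# standing in for the real card-code vocabulary size.
def _example_rnn_agent():
    agent = RNNAgent(embedding_shape=999, rngs=nnx.Rngs(0))
    rstate = agent.init_rnn_state(batch_size=8)  # (c, h) zeros for the LSTM
    return agent, rstate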
from typing import Optional, Any
import functools
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx.nn.normalization import _compute_stats, _normalize, _canonicalize_axes
from ygoai.rl.jax.modules import make_bin_params, bytes_to_bin, decode_id
default_kernel_init = nnx.initializers.lecun_normal()
default_bias_init = nnx.initializers.zeros
def first_from(*args, error_msg: str):
"""Return the first non-None argument.
If all arguments are None, raise a ValueError with the given error message.
Args:
*args: the arguments to check
error_msg: the error message to raise if all arguments are None
Returns:
The first non-None argument.
"""
for arg in args:
if arg is not None:
return arg
raise ValueError(error_msg)
def act(x, activation):
if activation == 'leaky_relu':
return nnx.leaky_relu(x, negative_slope=0.1)
elif activation == 'relu':
return nnx.relu(x)
elif activation == 'swish' or activation == 'silu':
return nnx.swish(x)
elif activation == 'gelu':
return nnx.gelu(x, approximate=False)
elif activation == "gelu_new":
return nnx.gelu(x, approximate=True)
else:
raise ValueError(f'Unknown activation: {activation}')
class MLP(nnx.Module):
def __init__(
self, in_channels, channels, *, last_lin=True,
activation='leaky_relu', use_bias=False, dtype=None, param_dtype=jnp.float32,
kernel_init=default_kernel_init, bias_init=default_bias_init,
last_kernel_init=default_kernel_init, rngs: nnx.Rngs):
if isinstance(channels, int):
channels = [channels]
self.in_channels = in_channels
self.channels = channels
self.last_lin = last_lin
self.activation = activation
self.n_layers = len(channels)
ic = in_channels
for i, c in enumerate(channels):
if i == len(channels) - 1 and last_lin:
l_kernel_init = last_kernel_init
else:
l_kernel_init = kernel_init
layer = nnx.Linear(
ic, c, dtype=dtype, param_dtype=param_dtype,
kernel_init=l_kernel_init, bias_init=bias_init,
use_bias=use_bias, rngs=rngs)
ic = c
setattr(self, f'fc{i+1}', layer)
def __call__(self, x):
for i in range(self.n_layers):
x = getattr(self, f'fc{i+1}')(x)
if i < self.n_layers - 1 or not self.last_lin:
x = act(x, self.activation)
return x
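# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): a two-layer MLP. With last_lin=True (the default) no activation
# follows the final linear layer.
def _example_mlp():
    mlp = MLP(64, (128, 32), rngs=nnx.Rngs(0))
    x = jnp.ones((4, 64))
    return mlp(x)  # (4, 32)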
class GLUMlp(nnx.Module):
def __init__(
self, in_channels, channels, out_channels=None,
*, use_bias=False, dtype=None, param_dtype=jnp.float32,
kernel_init=default_kernel_init, bias_init=default_bias_init,
rngs: nnx.Rngs):
self.in_channels = in_channels
self.channels = channels or 2 * in_channels
self.out_channels = out_channels or in_channels
linear = functools.partial(
nnx.Linear,
dtype=dtype,
param_dtype=param_dtype,
kernel_init=kernel_init,
bias_init=bias_init,
use_bias=use_bias,
rngs=rngs,
)
self.gate = linear(self.in_channels, self.channels)
self.up = linear(self.in_channels, self.channels)
self.down = linear(self.channels, self.out_channels)
def __call__(self, x):
g = self.gate(x)
x = nnx.silu(g) * self.up(x)
x = self.down(x)
return x
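# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): SwiGLU-style gating, computing down(silu(gate(x)) * up(x)).
def _example_glu_mlp():
    mlp = GLUMlp(64, 128, rngs=nnx.Rngs(0))
    x = jnp.ones((4, 64))
    return mlp(x)  # (4, 64); out_channels defaults to in_channels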
class BatchRenorm(nnx.Module):
"""BatchRenorm Module, implemented based on the Batch Renormalization paper (https://arxiv.org/abs/1702.03275).
and adapted from Flax's BatchNorm implementation:
https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228
Attributes:
use_running_average: if True, the statistics stored in batch_stats
will be used instead of computing the batch statistics on the input.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of
the batch statistics.
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma).
When the next layer is linear (also e.g. nn.relu), this can be disabled
since the scaling will be done by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, `[[0, 1], [2, 3]]` would independently batch-normalize over
the examples on the first two and last two devices. See `jax.lax.psum`
for more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""
def __init__(
self,
num_features: int,
*,
use_running_average: bool = False,
axis: int = -1,
momentum: float = 0.99,
epsilon: float = 1e-5,
dtype: Optional[jnp.dtype] = None,
param_dtype: jnp.dtype = jnp.float32,
use_bias: bool = True,
use_scale: bool = True,
bias_init: nnx.initializers.Initializer = nnx.initializers.zeros_init(),
scale_init: nnx.initializers.Initializer = nnx.initializers.ones_init(),
axis_name: Optional[str] = None,
axis_index_groups: Any = None,
use_fast_variance: bool = True,
rngs: nnx.Rngs,
):
feature_shape = (num_features,)
self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32))
self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32))
self.steps = nnx.BatchStat(jnp.zeros((), jnp.int64))
if use_scale:
key = rngs.params()
self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
else:
self.scale = nnx.Param(None)
if use_bias:
key = rngs.params()
self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
else:
self.bias = nnx.Param(None)
self.num_features = num_features
self.use_running_average = use_running_average
self.axis = axis
self.momentum = momentum
self.epsilon = epsilon
self.dtype = dtype
self.param_dtype = param_dtype
self.use_bias = use_bias
self.use_scale = use_scale
self.bias_init = bias_init
self.scale_init = scale_init
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups
self.use_fast_variance = use_fast_variance
def __call__(
self,
x,
use_running_average: Optional[bool] = None,
):
"""Normalizes the input using batch statistics.
Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats
will be used instead of computing the batch statistics on the input.
Returns:
Normalized inputs (the same shape as inputs).
"""
use_running_average = first_from(
use_running_average,
self.use_running_average,
error_msg="""
No `use_running_average` argument was provided to BatchNorm
as either a __call__ argument, class attribute, or nnx.flag.""",
)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
if use_running_average:
mean, var = self.mean.value, self.var.value
custom_mean = mean
custom_var = var
else:
mean, var = _compute_stats(
x,
reduction_axes,
dtype=self.dtype,
axis_name=self.axis_name,
axis_index_groups=self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
)
custom_mean = mean
custom_var = var
# The code below is implemented following the Batch Renormalization paper
ra_mean = self.mean.value
ra_var = self.var.value
steps = self.steps.value
r_max = 3
d_max = 5
r = 1
d = 0
std = jnp.sqrt(var + self.epsilon)
ra_std = jnp.sqrt(ra_var + self.epsilon)
r = jax.lax.stop_gradient(std / ra_std)
r = jnp.clip(r, 1 / r_max, r_max)
d = jax.lax.stop_gradient((mean - ra_mean) / ra_std)
d = jnp.clip(d, -d_max, d_max)
tmp_var = var / (r**2)
tmp_mean = mean - d * jnp.sqrt(custom_var) / r
# Warm up batch renorm for 100_000 steps to build up proper running statistics
warmed_up = jnp.greater_equal(steps, 100_000).astype(jnp.float32)
custom_var = warmed_up * tmp_var + (1. - warmed_up) * custom_var
custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean
self.mean.value = (
self.momentum * ra_mean + (1 - self.momentum) * mean
)
self.var.value = (
self.momentum * ra_var + (1 - self.momentum) * var
)
self.steps.value = steps + 1
return _normalize(
x,
custom_mean,
custom_var,
self.scale.value,
self.bias.value,
reduction_axes,
feature_axes,
self.dtype,
self.epsilon,
)
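# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): in training mode the running statistics are updated on every
# call; for the first 100_000 steps the plain batch statistics are used, after
# which the renormalization correction (r, d) kicks in.
def _example_batch_renorm():
    bn = BatchRenorm(16, rngs=nnx.Rngs(0))
    x = jnp.ones((8, 16))
    y_train = bn(x, use_running_average=False)  # updates mean/var/steps
    y_eval = bn(x, use_running_average=True)    # uses stored statistics
    return y_train, y_eval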
import jax
import jax.numpy as jnp
from flax import nnx
default_kernel_init = nnx.initializers.lecun_normal()
default_bias_init = nnx.initializers.zeros_init()
class OptimizedLSTMCell(nnx.Module):
def __init__(
self, in_features, features: int, *,
gate_fn=nnx.sigmoid, activation_fn=nnx.tanh,
kernel_init=default_kernel_init, bias_init=default_bias_init,
recurrent_kernel_init=nnx.initializers.orthogonal(),
dtype=None, param_dtype=jnp.float32, rngs,
):
self.features = features
self.gate_fn = gate_fn
self.activation_fn = activation_fn
self.fc_i = nnx.Linear(
in_features, 4 * features,
use_bias=False, kernel_init=kernel_init,
bias_init=bias_init, dtype=dtype,
param_dtype=param_dtype, rngs=rngs,
)
self.fc_h = nnx.Linear(
features, 4 * features,
use_bias=True, kernel_init=recurrent_kernel_init,
bias_init=bias_init, dtype=dtype,
param_dtype=param_dtype, rngs=rngs,
)
def __call__(self, carry, inputs):
c, h = carry
dense_i = self.fc_i(inputs)
dense_h = self.fc_h(h)
i, f, g, o = jnp.split(dense_i + dense_h, indices_or_sections=4, axis=-1)
i, f, g, o = self.gate_fn(i), self.gate_fn(f), self.activation_fn(g), self.gate_fn(o)
new_c = f * c + i * g
new_h = o * self.activation_fn(new_c)
return (new_c, new_h), new_h
class GRUCell(nnx.Module):
def __init__(
self, in_features: int, features: int, *,
gate_fn=nnx.sigmoid, activation_fn=nnx.tanh,
kernel_init=default_kernel_init, bias_init=default_bias_init,
recurrent_kernel_init=nnx.initializers.orthogonal(),
dtype=None, param_dtype=jnp.float32, rngs,
):
self.features = features
self.gate_fn = gate_fn
self.activation_fn = activation_fn
self.fc_i = nnx.Linear(
in_features, 3 * features,
use_bias=True, kernel_init=kernel_init,
bias_init=bias_init, dtype=dtype,
param_dtype=param_dtype, rngs=rngs,
)
self.fc_h = nnx.Linear(
features, 3 * features,
use_bias=True, kernel_init=recurrent_kernel_init,
bias_init=bias_init, dtype=dtype,
param_dtype=param_dtype, rngs=rngs,
)
def __call__(self, carry, inputs):
h = carry
dense_i = self.fc_i(inputs)
dense_h = self.fc_h(h)
ir, iz, in_ = jnp.split(dense_i, indices_or_sections=3, axis=-1)
hr, hz, hn = jnp.split(dense_h, indices_or_sections=3, axis=-1)
r = self.gate_fn(ir + hr)
z = self.gate_fn(iz + hz)
n = self.activation_fn(in_ + r * hn)
new_h = (1.0 - z) * n + z * h
return new_h, new_h
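# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): one recurrent step for each cell. The LSTM carry is a (c, h)
# pair, the GRU carry is a single hidden state; both return the new hidden
# state as the per-step output y.
def _example_cells():
    rngs = nnx.Rngs(0)
    lstm = OptimizedLSTMCell(8, 16, rngs=rngs)
    carry = (jnp.zeros((2, 16)), jnp.zeros((2, 16)))
    carry, y = lstm(carry, jnp.ones((2, 8)))   # y: (2, 16)
    gru = GRUCell(8, 16, rngs=rngs)
    h = jnp.zeros((2, 16))
    h, y = gru(h, jnp.ones((2, 8)))            # y: (2, 16)
    return carry, h, y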
import functools
import numpy as np
import jax.numpy as jnp
from flax import nnx
from ygoai.rl.jax.nnx.modules import default_kernel_init, default_bias_init, act, GLUMlp, first_from
def precompute_freqs_cis(
dim: int, end: int, theta=10000.0, dtype=jnp.float32
):
# returns:
# cos, sin: (end, dim)
freqs = 1.0 / \
(theta ** (np.arange(0, dim, 2, dtype=np.float32)[: (dim // 2)] / dim))
t = np.arange(end, dtype=np.float32) # type: ignore
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
freqs = np.concatenate((freqs, freqs), axis=-1)
cos, sin = np.cos(freqs), np.sin(freqs)
return jnp.array(cos, dtype=dtype), jnp.array(sin, dtype=dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return jnp.concatenate((-x2, x1), axis=-1)
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id=None):
# inputs:
# x: (batch_size, seq_len, num_heads, head_dim)
# cos, sin: (seq_len, head_dim)
# position_id: (batch_size, seq_len)
# returns:
# x: (batch_size, seq_len, num_heads, head_dim)
if position_id is None:
q_pos = jnp.arange(q.shape[1])[None, :]
k_pos = jnp.arange(k.shape[1])[None, :]
else:
q_pos = position_id
k_pos = position_id
cos_q = jnp.take(cos, q_pos, axis=0)[:, :, None, :]
sin_q = jnp.take(sin, q_pos, axis=0)[:, :, None, :]
q = (q * cos_q) + (rotate_half(q) * sin_q)
cos_k = jnp.take(cos, k_pos, axis=0)[:, :, None, :]
sin_k = jnp.take(sin, k_pos, axis=0)[:, :, None, :]
k = (k * cos_k) + (rotate_half(k) * sin_k)
return q, k
def make_apply_rope(head_dim, max_len, dtype):
cos, sin = precompute_freqs_cis(
dim=head_dim, end=max_len, dtype=dtype)
def add_pos(q, k, p=None): return apply_rotary_pos_emb_index(
q, k, cos, sin, p)
return add_pos
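# Hedged shape sketch (not part of the original commit; sizes are illustrative
# assumptions): rotary embeddings applied to q/k of shape
# (batch, seq_len, num_heads, head_dim), with positions 0..seq_len-1 taken
# from the precomputed cos/sin tables.
def _example_rope():
    add_pos = make_apply_rope(head_dim=64, max_len=128, dtype=jnp.float32)
    q = jnp.ones((2, 16, 4, 64))
    k = jnp.ones((2, 16, 4, 64))
    q, k = add_pos(q, k)  # shapes unchanged
    return q, k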
# from nnx.MultiHeadAttention
class MultiHeadAttention(nnx.Module):
"""Multi-head attention.
Example usage::
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from flax import nnx
    >>> layer = MultiHeadAttention(
    ...     num_heads=8, in_features=16, qkv_features=16, rngs=nnx.Rngs(0))
    >>> shape = (4, 3, 16)
    >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
    >>> q = jax.random.uniform(key1, shape)
    >>> k = jax.random.uniform(key2, shape)
    >>> v = jax.random.uniform(key3, shape)
    >>> # different inputs for inputs_q, inputs_k and inputs_v
    >>> out = layer(q, k, v)
    >>> # equivalent to layer(q, k, k)
    >>> out = layer(q, k)
    >>> # self-attention, equivalent to layer(q, q, q)
    >>> out = layer(q)
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: infer from inputs and params)
param_dtype: the dtype passed to parameter initializers (default: float32)
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weight is masked randomly using
dropout, whereas if true, the attention weights are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
out_kernel_init: optional initializer for the kernel of the output Dense layer,
if None, the kernel_init is used.
bias_init: initializer for the bias of the Dense layers.
out_bias_init: optional initializer for the bias of the output Dense layer,
if None, the bias_init is used.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts query,
    key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,
    num_heads, value_channels]`
"""
def __init__(
self,
num_heads: int,
in_features: int,
qkv_features=None,
out_features=None,
*,
dtype=None,
param_dtype=jnp.float32,
rope=False,
max_len=2048,
broadcast_dropout=True,
dropout_rate=0.0,
deterministic=None,
precision=None,
kernel_init=default_kernel_init,
out_kernel_init=None,
bias_init=default_bias_init,
out_bias_init=None,
use_bias: bool = True,
attention_fn=nnx.dot_product_attention,
# Deprecated, will be removed.
qkv_dot_general=None,
out_dot_general=None,
qkv_dot_general_cls=None,
out_dot_general_cls=None,
rngs: nnx.Rngs,
):
self.num_heads = num_heads
self.in_features = in_features
self.qkv_features = (
qkv_features if qkv_features is not None else in_features
)
self.out_features = (
out_features if out_features is not None else in_features
)
self.dtype = dtype
self.param_dtype = param_dtype
self.rope = rope
self.max_len = max_len
self.broadcast_dropout = broadcast_dropout
self.dropout_rate = dropout_rate
self.deterministic = deterministic
self.precision = precision
self.kernel_init = kernel_init
self.out_kernel_init = out_kernel_init
self.bias_init = bias_init
self.out_bias_init = out_bias_init
self.use_bias = use_bias
self.attention_fn = attention_fn
self.qkv_dot_general = qkv_dot_general
self.out_dot_general = out_dot_general
self.qkv_dot_general_cls = qkv_dot_general_cls
self.out_dot_general_cls = out_dot_general_cls
if self.qkv_features % self.num_heads != 0:
raise ValueError(
f'Memory dimension ({self.qkv_features}) must be divisible by '
f"'num_heads' heads ({self.num_heads})."
)
self.head_dim = self.qkv_features // self.num_heads
linear_general = functools.partial(
nnx.LinearGeneral,
in_features=self.in_features,
out_features=(self.num_heads, self.head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
precision=self.precision,
dot_general=self.qkv_dot_general,
dot_general_cls=self.qkv_dot_general_cls,
)
# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
self.query = linear_general(rngs=rngs)
self.key = linear_general(rngs=rngs)
self.value = linear_general(rngs=rngs)
self.out = nnx.LinearGeneral(
in_features=(self.num_heads, self.head_dim),
out_features=self.out_features,
axis=(-2, -1),
kernel_init=self.out_kernel_init or self.kernel_init,
bias_init=self.out_bias_init or self.bias_init,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
dot_general=self.out_dot_general,
dot_general_cls=self.out_dot_general_cls,
rngs=rngs,
)
self.rngs = rngs if dropout_rate > 0.0 else None
def __call__(
self,
inputs_q,
inputs_k=None,
inputs_v=None,
*,
mask=None,
deterministic=None,
rngs=None,
sow_weights=False,
):
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
If both inputs_k and inputs_v are None, they will both copy the value of
inputs_q (self attention).
If only inputs_v is None, it will copy the value of inputs_k.
Args:
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
inputs_k will copy the value of inputs_q.
inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
inputs_v will copy the value of inputs_k.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
deterministic: if false, the attention weight is masked randomly using
dropout, whereas if true, the attention weights are deterministic.
rngs: container for random number generators to generate the dropout
mask when `deterministic` is False. The `rngs` container should have a
`dropout` key.
sow_weights: if ``True``, the attention weights are sowed into the
'intermediates' collection.
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
if rngs is None:
rngs = self.rngs
if inputs_k is None:
if inputs_v is not None:
raise ValueError(
'`inputs_k` cannot be None if `inputs_v` is not None. '
'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
'value to `inputs_k` and leave `inputs_v` as None.'
)
inputs_k = inputs_q
if inputs_v is None:
inputs_v = inputs_k
if inputs_q.shape[-1] != self.in_features:
raise ValueError(
f'Incompatible input dimension, got {inputs_q.shape[-1]} '
f'but module expects {self.in_features}.'
)
query = self.query(inputs_q)
key = self.key(inputs_k)
value = self.value(inputs_v)
if self.rope:
add_pos = make_apply_rope(
self.head_dim, self.max_len, self.dtype)
else:
def add_pos(q, k, p=None): return (q, k)
query, key = add_pos(query, key)
if self.dropout_rate > 0.0: # Require `deterministic` only if using dropout.
deterministic = first_from(
deterministic,
self.deterministic,
error_msg="""
No `deterministic` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)
if not deterministic:
if rngs is None:
raise ValueError(
"'rngs' must be provided if 'dropout_rng' is not given."
)
dropout_rng = rngs.dropout()
else:
dropout_rng = None
else:
deterministic = True
dropout_rng = None
# apply attention
x = self.attention_fn(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=deterministic,
dtype=self.dtype,
precision=self.precision,
module=self if sow_weights else None,
)
# back to the original inputs dimensions
out = self.out(x)
return out
def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
"""1D Sinusoidal Position Embedding Initializer.
Args:
max_len: maximum possible length for the input.
min_scale: float: minimum frequency-scale in sine grating.
max_scale: float: maximum frequency-scale in sine grating.
Returns:
output: init function returning `(1, max_len, d_feature)`
"""
def init(key, shape, dtype=np.float32):
"""Sinusoidal init."""
del key, dtype
d_feature = shape[-1]
pe = np.zeros((max_len, d_feature), dtype=np.float32)
position = np.arange(0, max_len)[:, np.newaxis]
scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
div_term = min_scale * \
np.exp(np.arange(0, d_feature // 2) * scale_factor)
pe[:, : d_feature // 2] = np.sin(position * div_term)
pe[:, d_feature // 2: 2 * (d_feature // 2)] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
return jnp.array(pe)
return init
class PositionalEncoding(nnx.Module):
"""Adds (optionally learned) positional embeddings to the inputs.
"""
def __init__(
self, in_channels, *, max_len=512, learned=False,
initializer=sinusoidal_init, rngs: nnx.Rngs):
self.pos_emb_shape = (1, max_len, in_channels)
self.max_len = max_len
self.learned = learned
init = initializer(max_len=max_len)(None, self.pos_emb_shape)
if learned:
self.pos_embedding = nnx.Param(init)
else:
self.pos_embedding = None
def __call__(self, x):
assert x.ndim == 3, (
'Number of dimensions should be 3, but it is: %d' % x.ndim
)
length = x.shape[1]
if self.pos_embedding is None:
pos_embedding = sinusoidal_init(max_len=self.max_len)(
None, self.pos_emb_shape
)
else:
pos_embedding = self.pos_embedding.value
return x + pos_embedding[:, :length, :]
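# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): fixed sinusoidal embeddings added to a (batch, length, features)
# input.
def _example_positional_encoding():
    pe = PositionalEncoding(32, rngs=nnx.Rngs(0))
    x = jnp.ones((2, 10, 32))
    return pe(x)  # (2, 10, 32)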
class MlpBlock(nnx.Module):
def __init__(
self, in_channels, channels=None, out_channels=None,
*, activation="gelu", use_bias=False, dtype=jnp.float32,
param_dtype=jnp.float32, kernel_init=default_kernel_init,
bias_init=default_bias_init, rngs: nnx.Rngs):
self.in_channels = in_channels
self.channels = channels or 4 * in_channels
self.out_channels = out_channels or in_channels
self.activation = activation
linear = functools.partial(
nnx.Linear,
dtype=dtype,
param_dtype=param_dtype,
kernel_init=kernel_init,
bias_init=bias_init,
use_bias=use_bias,
rngs=rngs,
)
self.fc1 = linear(self.in_channels, self.channels)
self.fc2 = linear(self.channels, self.out_channels)
def __call__(self, x):
    x = self.fc1(x)
    x = act(x, self.activation)
    x = self.fc2(x)
    return x
class EncoderLayer(nnx.Module):
def __init__(
self, d_model, n_heads, dim_feedforward=None,
*, llama=False, activation="relu", rope=False, rope_max_len=2048,
attn_pdrop=0.0, resid_pdrop=0.0, layer_norm_epsilon=1e-6,
dtype=None, param_dtype=jnp.float32, kernel_init=default_kernel_init,
bias_init=default_bias_init, use_bias=False, rngs: nnx.Rngs
):
if not llama and rope:
raise ValueError("RoPE can only be used with llama=True")
self.d_model = d_model
self.n_heads = n_heads
self.dim_feedforward = dim_feedforward
self.dtype = dtype
if llama:
norm = nnx.RMSNorm
mlp = GLUMlp
else:
norm = nnx.LayerNorm
mlp = functools.partial(MlpBlock, activation=activation)
self.ln1 = norm(
d_model, epsilon=layer_norm_epsilon, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.attn = MultiHeadAttention(
n_heads, d_model, rope=rope, max_len=rope_max_len,
dtype=dtype, param_dtype=param_dtype,
use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init,
dropout_rate=attn_pdrop, rngs=rngs)
self.dropout1 = nnx.Dropout(rate=resid_pdrop, rngs=rngs if resid_pdrop > 0.0 else None)
self.ln2 = norm(
d_model, epsilon=layer_norm_epsilon, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.mlp = mlp(
d_model, channels=dim_feedforward, dtype=dtype, param_dtype=param_dtype,
use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init, rngs=rngs)
self.dropout2 = nnx.Dropout(rate=resid_pdrop, rngs=rngs if resid_pdrop > 0.0 else None)
def __call__(
self, x, attn_scale=None, attn_bias=None,
output_scale=None, output_bias=None, *, src_key_padding_mask=None):
x = jnp.asarray(x, self.dtype)
y = self.ln1(x)
if src_key_padding_mask is None:
mask = None
else:
mask = ~src_key_padding_mask[:, None, None, :]
y = self.attn(y, y, y, mask=mask)
y = self.dropout1(y)
if attn_scale is not None:
y = y * attn_scale
if attn_bias is not None:
y = y + attn_bias
x = x + y
y = self.ln2(x)
y = self.mlp(y)
y = self.dropout2(y)
if output_scale is not None:
y = y * output_scale
if output_bias is not None:
y = y + output_bias
x = x + y
return x
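# Hedged usage sketch (not part of the original commit; sizes are illustrative
# assumptions): a pre-norm transformer block. src_key_padding_mask marks padded
# positions with True and is converted to an attention mask internally.
def _example_encoder_layer():
    layer = EncoderLayer(128, 4, rngs=nnx.Rngs(0))
    x = jnp.ones((2, 10, 128))
    pad = jnp.zeros((2, 10), dtype=bool)
    return layer(x, src_key_padding_mask=pad)  # (2, 10, 128)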
class DecoderLayer(nnx.Module):
def __init__(
self, d_model, n_heads, dim_feedforward=None,
*, llama=False, activation="relu", attn_pdrop=0.0, resid_pdrop=0.0,
layer_norm_epsilon=1e-6, dtype=None, param_dtype=jnp.float32,
kernel_init=default_kernel_init, bias_init=default_bias_init,
use_bias=False, rngs: nnx.Rngs
):
self.d_model = d_model
self.n_heads = n_heads
self.dim_feedforward = dim_feedforward
self.dtype = dtype
if llama:
norm = nnx.RMSNorm
mlp = GLUMlp
else:
norm = nnx.LayerNorm
mlp = functools.partial(MlpBlock, activation=activation)
self.ln1 = norm(
d_model, epsilon=layer_norm_epsilon, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.self_attn = nnx.MultiHeadAttention(
n_heads, d_model, dtype=dtype, param_dtype=param_dtype,
use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init,
dropout_rate=attn_pdrop, rngs=rngs)
self.dropout1 = nnx.Dropout(rate=resid_pdrop, rngs=rngs if resid_pdrop > 0.0 else None)
self.ln2 = norm(
d_model, epsilon=layer_norm_epsilon, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.cross_attn = nnx.MultiHeadAttention(
n_heads, d_model, dtype=dtype, param_dtype=param_dtype,
use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init,
dropout_rate=attn_pdrop, rngs=rngs)
self.dropout2 = nnx.Dropout(rate=resid_pdrop, rngs=rngs if resid_pdrop > 0.0 else None)
self.ln3 = norm(
d_model, epsilon=layer_norm_epsilon, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
self.mlp = mlp(
d_model, channels=dim_feedforward, dtype=dtype, param_dtype=param_dtype,
use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init, rngs=rngs)
self.dropout3 = nnx.Dropout(rate=resid_pdrop, rngs=rngs if resid_pdrop > 0.0 else None)
def __call__(
self, tgt, memory, *, tgt_key_padding_mask=None, memory_key_padding_mask=None):
y = self.ln1(tgt)
if tgt_key_padding_mask is None:
mask = None
else:
mask = ~tgt_key_padding_mask[:, None, None, :]
y = self.self_attn(y, y, y, mask=mask)
y = self.dropout1(y)
x = y + tgt
y = self.ln2(x)
if memory_key_padding_mask is None:
mask = None
else:
mask = ~memory_key_padding_mask[:, None, None, :]
y = self.cross_attn(y, memory, memory, mask=mask)
y = self.dropout2(y)
x = y + x
y = self.ln3(x)
y = self.mlp(y)
y = self.dropout3(y)
x = x + y
return x
@@ -1528,7 +1528,7 @@ public:
        "max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
        "record"_.Bind(false), "async_reset"_.Bind(false),
        "greedy_reward"_.Bind(true), "timeout"_.Bind(600),
-       "oppo_info"_.Bind(false));
+       "oppo_info"_.Bind(false), "max_steps"_.Bind(1000));
  }

  template <typename Config>
  static decltype(auto) StateSpec(const Config &conf) {
@@ -1629,6 +1629,7 @@ protected:
  std::uniform_int_distribution<uint64_t> dist_int_;
  bool done_{true};
+ long step_count_{0};
  bool duel_started_{false};
  uint32_t eng_flag_{0};
@@ -1947,6 +1948,7 @@ public:
    discard_hand_ = false;
    done_ = false;
+   step_count_ = 0;
    // update_time_stat(_start, reset_time_count_, reset_time_2_);
    // _start = clock();
@@ -2227,6 +2229,14 @@ public:
      next();
    }
+   step_count_++;
+   if (!done_ && (step_count_ >= spec_.config["max_steps"_])) {
+     PlayerId winner = lp_[0] > lp_[1] ? 0 : 1;
+     _duel_end(winner, 0x01);
+     done_ = true;
+     legal_actions_.clear();
+   }
    float reward = 0;
    int reason = 0;
    if (done_) {
@@ -2334,6 +2344,9 @@ public:
    if (n_options == 0) {
      state["info:num_options"_] = 1;
      state["obs:global_"_][22] = uint8_t(1);
+     // if (step_count_ >= spec_.config["max_steps"_]) {
+     //   fmt::println("Max steps reached return");
+     // }
      return;
    }