Commit 4f2ad15b authored by sbl1996@126.com

Add cleanba PPO for TPU

parent 50353ff4
@@ -135,6 +135,8 @@ class Args:
     """the number of iterations (computed in runtime)"""
     world_size: int = 0
     """the number of processes (computed in runtime)"""
+    num_embeddings: Optional[int] = None
+    """the number of embeddings (computed in runtime)"""
 def make_env(args, num_envs, num_threads, mode='self'):
@@ -148,7 +150,7 @@ def make_env(args, num_envs, num_threads, mode='self'):
         deck2=args.deck2,
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
-        play_mode='self',
+        play_mode=mode,
     )
     envs.num_envs = num_envs
     envs = RecordEpisodeStatistics(envs)
...
@@ -221,13 +221,8 @@ def actor(
         return logits, value
     if args.compile:
-        # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
         predict_step = torch.compile(predict_step, mode=args.compile)
         agent_r = agent
-        # example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
-        # with torch.no_grad():
-        #     agent_r = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
     else:
         agent_r = agent
...
import os
import queue
import random
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent import PPOAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"the name of this experiment"
seed: int = 1
"seed of the experiment"
save_model: bool = False
"whether to save model into the `runs/{run_name}` folder"
log_frequency: int = 2
"the logging frequency of the model performance (in terms of `updates`)"
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 50000000
"total timesteps of the experiments"
learning_rate: float = 1e-3
"the learning rate of the optimizer"
local_num_envs: int = 128
"the number of parallel game environments"
local_env_threads: Optional[int] = None
"the number of threads to use for environment"
num_actor_threads: int = 2
"the number of actor threads to use"
num_steps: int = 128
"the number of steps to run in each environment per policy rollout"
anneal_lr: bool = True
"Toggle learning rate annealing for policy and value networks"
gamma: float = 1.0
"the discount factor gamma"
gae_lambda: float = 0.98
"the lambda for the general advantage estimation"
num_minibatches: int = 8
"the number of mini-batches"
gradient_accumulation_steps: int = 1
"the number of gradient accumulation steps before performing an optimization step"
update_epochs: int = 2
"the K epochs to update the policy"
norm_adv: bool = False
"Toggles advantages normalization"
clip_coef: float = 0.2
"the surrogate clipping coefficient"
ent_coef: float = 0.01
"coefficient of the entropy"
vf_coef: float = 0.5
"coefficient of the value function"
max_grad_norm: float = 1.0
"the maximum norm for the gradient clipping"
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0])
"the device ids that actor workers will use"
learner_device_ids: List[int] = field(default_factory=lambda: [1])
"the device ids that learner workers will use"
distributed: bool = False
"whether to use `jax.distirbuted`"
concurrency: bool = True
"whether to run the actor and learner concurrently"
bfloat16: bool = False
"whether to use bfloat16 for the agent"
thread_affinity: bool = False
"whether to use thread affinity for the environment"
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
num_updates: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
    global_learner_devices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logprobs: list
values: list
rewards: list
learns: list
def create_agent(args):
return PPOAgent(
channels=args.num_channels,
num_card_layers=args.num_layers,
num_action_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def apply_fn(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
):
logits, value, _valid = create_agent(args).apply(params, next_obs)
return logits, value
def get_action(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
):
return apply_fn(params, next_obs)[0].argmax(axis=1)
@jax.jit
def get_action_and_value(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
key: jax.random.PRNGKey,
):
next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
logits, value = apply_fn(params, next_obs)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
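        # -log(-log(u)) with u ~ Uniform(0, 1) is a Gumbel(0, 1) sample, so
        # argmax(logits + gumbel) draws from Categorical(softmax(logits))
        # without materializing the softmax.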
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
return next_obs, action, logprob, value.squeeze(), key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
ai_player1 = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
next_value1 = next_value2 = 0
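    # prepare_data stacks the per-step Transitions along a new leading time axis and
    # splits along the env axis (axis=1) into one shard per learner device.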
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree_map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
storage = []
for _ in range(0, args.num_steps):
global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
learn = next_to_play == ai_player1
inference_time_start = time.time()
cached_next_obs, action, logprob, value, key = get_action_and_value(params, cached_next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
next_nonterminal = 1 - next_done.astype(np.float32)
next_value1 = np.where(learn, value, next_value1) * next_nonterminal
next_value2 = np.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
actions=action,
logprobs=logprob,
values=value,
rewards=next_reward,
learns=learn,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
partitioned_storage = prepare_data(storage)
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_learn = ai_player1 == next_to_play
sharded_data = jax.tree_map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_done, next_value1, next_value2, next_learn))
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
SPS = int((global_step - warmup_step) / (time.time() - start_time))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_episodic_return={avg_episodic_return}, rollout_time={np.mean(rollout_time)}"
)
print("SPS:", SPS)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", np.mean(envs.returned_episode_lengths), global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar(
"charts/SPS_update",
int(
args.local_num_envs
* args.num_steps
* len_actor_device_ids
* args.num_actor_threads
* args.world_size
/ (time.time() - update_time_start)
),
global_step,
)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
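    # e.g. with the defaults (local_num_envs=128, num_steps=128, num_actor_threads=2,
    # one actor device): local_batch_size = 128 * 128 * 2 * 1 = 32768 and
    # local_minibatch_size = 32768 // 8 = 4096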
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
    global_learner_devices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(
args, args.seed, args.local_num_envs, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
def linear_schedule(count):
        # anneal learning rate linearly after one training iteration which contains
        # (args.num_minibatches * args.update_epochs) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
agent = create_agent(args)
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros_like(x)]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=args.gradient_accumulation_steps,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logprob_entropy_value(
params: flax.core.FrozenDict,
obs: np.ndarray,
actions: np.ndarray,
):
logits, value, valid = create_agent(args).apply(params, obs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
p_log_p = logits * jax.nn.softmax(logits)
entropy = -p_log_p.sum(-1)
return logprob, entropy, value.squeeze(), valid
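    # Self-play GAE in one backward scan: the two alternating players keep separate
    # bootstrap values, pending rewards and lambda-return accumulators (suffix 1 for the
    # player flagged by `learns`, suffix 2 for its opponent). The terminal reward is
    # credited to the two streams with opposite signs, and each stream only advances on
    # the steps where its player acted.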
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - curvalues
delta2 = reward2 + gamma * nextvalues2 - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, advantages
compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)
@jax.jit
def compute_gae(
agent_state: TrainState,
next_obs: np.ndarray,
next_done: np.ndarray,
next_value1: np.ndarray,
next_value2: np.ndarray,
next_learn: np.ndarray,
storage: Transition,
):
next_value = create_agent(args).apply(agent_state.params, next_obs)[1].squeeze()
next_value1 = jnp.where(next_learn, next_value, next_value1)
next_value2 = jnp.where(next_learn, next_value2, next_value)
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
_, advantages = jax.lax.scan(
compute_gae_once, carry, (dones[1:], storage.values, storage.rewards, storage.learns), reverse=True
)
return advantages
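    # Standard PPO clipped surrogate: max(-A * ratio, -A * clip(ratio, 1 - eps, 1 + eps)),
    # plus an MSE value loss and an entropy bonus; every term is averaged only over
    # valid (non-padding) samples via masked_mean.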
def ppo_loss(params, obs, actions, behavior_logprobs, advantages, target_values):
newlogprob, entropy, newvalue, valid = get_logprob_entropy_value(params, obs, actions)
logratio = newlogprob - behavior_logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_done: List,
sharded_next_value1: List,
sharded_next_value2: List,
sharded_next_learn: List,
key: jax.random.PRNGKey,
):
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree_map(lambda *x: jnp.concatenate(x), *sharded_next_obs)
next_done, next_value1, next_value2, next_learn = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_value1, sharded_next_value2, sharded_next_learn]
]
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
advantages = compute_gae(
agent_state, next_obs, next_done, next_value1, next_value2, next_learn, storage)
target_values = advantages + storage.values
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def flatten(x):
return x.reshape((-1,) + x.shape[2:])
def convert_data(x: jnp.ndarray):
x = jax.random.permutation(subkey, x)
x = jnp.reshape(x, (args.num_minibatches * args.gradient_accumulation_steps, -1) + x.shape[1:])
return x
flatten_storage = jax.tree_map(flatten, storage)
flatten_advantages = flatten(advantages)
flatten_target_values = flatten(target_values)
shuffled_storage = jax.tree_map(convert_data, flatten_storage)
shuffled_advantages = convert_data(flatten_advantages)
shuffled_target_values = convert_data(flatten_target_values)
def update_minibatch(agent_state, minibatch):
mb_obs, mb_actions, mb_behavior_logprobs, mb_advantages, mb_target_values = minibatch
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params,
mb_obs,
mb_actions,
mb_behavior_logprobs,
mb_advantages,
mb_target_values,
)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_storage.obs,
shuffled_storage.actions,
shuffled_storage.logprobs,
shuffled_advantages,
shuffled_target_values,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
        devices=global_learner_devices,
)
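    # Cleanba-style actor-learner pipeline: each actor thread pulls the latest params from
    # its params_queue, rolls out, and pushes sharded data into its rollout_queue; the
    # learner consumes one payload per actor thread, runs multi_device_update, and pushes
    # fresh params back. maxsize=1 keeps actors at most one policy version behind.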
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_storages = []
sharded_next_obss = []
sharded_next_dones = []
sharded_next_values1 = []
sharded_next_values2 = []
sharded_next_learns = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
sharded_storage,
sharded_next_obs,
sharded_next_done,
sharded_next_value1,
sharded_next_value2,
sharded_next_learn,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_storages.append(sharded_storage)
sharded_next_obss.append(sharded_next_obs)
sharded_next_dones.append(sharded_next_done)
sharded_next_values1.append(sharded_next_value1)
sharded_next_values2.append(sharded_next_value2)
sharded_next_learns.append(sharded_next_learn)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
sharded_storages,
sharded_next_obss,
sharded_next_dones,
sharded_next_values1,
sharded_next_values2,
sharded_next_learns,
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if learner_policy_version >= args.num_updates:
break
envs.close()
writer.close()
\ No newline at end of file
import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
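    # step() merges gymnasium's terminated/truncated into a single `dones` array and
    # exposes the last finished episode's return/length under infos["r"] / infos["l"]
    # (unfinished envs keep the values from their previously completed episode).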
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns = np.where(
dones, self.episode_returns, self.returned_episode_returns
)
self.returned_episode_lengths = np.where(
dones, self.episode_lengths, self.returned_episode_lengths
)
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
\ No newline at end of file
from typing import Tuple, Union, Optional
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding
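# Card ids are encoded as two bytes (high, low); decode_id recombines them into a single
# integer index for the id embedding table.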
def decode_id(x):
x = x[..., 0] * 256 + x[..., 1]
return x
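# Numeric values (ATK/DEF/LP) are also 2-byte encoded. bytes_to_bin decodes them and
# soft-assigns each value to a set of bins (a clipped linear ramp per bin edge);
# make_bin_params places 24 of the 32 edges below 8000, giving finer resolution there.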
def bytes_to_bin(x, points, intervals):
points = points.astype(x.dtype)
intervals = intervals.astype(x.dtype)
x = decode_id(x)
x = jnp.expand_dims(x, -1)
return jnp.clip((x - points + intervals) / intervals, 0, 1)
def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = jnp.linspace(0, x_max1, sig_bins + 1, dtype=jnp.float32)[1:]
points2 = jnp.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=jnp.float32)[1:]
points = jnp.concatenate([points1, points2], axis=0)
intervals = jnp.concatenate([points[0:1], points[1:] - points[:-1]], axis=0)
return points, intervals
default_embed_init = nn.initializers.uniform(scale=0.0001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.0001)
class MLP(nn.Module):
features: Tuple[int, ...] = (128, 128)
last_lin: bool = True
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
@nn.compact
def __call__(self, x):
n = len(self.features)
for i, c in enumerate(self.features):
x = nn.Dense(
c, dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=self.kernel_init, use_bias=False)(x)
if i < n - 1 or not self.last_lin:
x = nn.relu(x)
return x
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
div = 8
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
x_a_phase = embed(4, c // div)(x[:, :, 3])
x_a_cancel = embed(3, c // div)(x[:, :, 4])
x_a_finish = embed(3, c // div // 2)(x[:, :, 5])
x_a_position = embed(9, c // div // 2)(x[:, :, 6])
x_a_option = embed(6, c // div // 2)(x[:, :, 7])
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
return jnp.concatenate([
x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib], axis=-1)
class Encoder(nn.Module):
channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
if self.embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = MLP((c // 8,), last_lin=False, dtype=self.dtype, param_dtype=self.param_dtype)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
action_encoder = ActionEncoder(channels=c, dtype=self.dtype, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_cards_1 = x_cards[:, :, :12].astype(jnp.int32)
x_cards_2 = x_cards[:, :, 12:].astype(self.dtype or jnp.float32)
x_id = decode_id(x_cards_1[:, :, :2])
x_id = id_embed(x_id)
x_id = MLP(
(c, c // 4), dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x_cards_1[:, :, 3]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x_cards_1[:, :, 4])
x_position = embed(9, c // 16)(x_cards_1[:, :, 5])
        x_overlay = embed(2, c // 16)(x_cards_1[:, :, 6])
x_attribute = embed(8, c // 16)(x_cards_1[:, :, 7])
x_race = embed(27, c // 16)(x_cards_1[:, :, 8])
x_level = embed(14, c // 16)(x_cards_1[:, :, 9])
x_counter = embed(16, c // 16)(x_cards_1[:, :, 10])
x_negated = embed(3, c // 16)(x_cards_1[:, :, 11])
x_atk = num_transform(x_cards_2[:, :, 0:2])
x_atk = fc_layer(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x_cards_2[:, :, 2:4])
x_def = fc_layer(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_layer(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:])
x_feat = jnp.concatenate([
            x_owner, x_position, x_overlay, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_feat = layer_norm()(x_feat)
f_cards = jnp.concatenate([x_id, x_feat], axis=-1)
f_cards = f_cards + f_loc + f_seq
num_heads = max(2, c // 128)
for _ in range(self.num_card_layers):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1))
f_cards = jnp.concatenate([f_na_card, f_cards], axis=1)
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
f_cards = layer_norm()(f_cards)
x_global_1 = x_global[:, :4].astype(self.dtype or jnp.float32)
x_g_lp = fc_layer(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = fc_layer(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_2 = x_global[:, 4:8].astype(jnp.int32)
x_g_turn = embed(20, c // 8)(x_global_2[:, 0])
x_g_phase = embed(11, c // 8)(x_global_2[:, 1])
x_g_if_first = embed(2, c // 8)(x_global_2[:, 2])
x_g_is_my_turn = embed(2, c // 8)(x_global_2[:, 3])
x_global_3 = x_global[:, 8:22].astype(jnp.int32)
x_g_cs = count_embed(x_global_3).reshape((batch_size, -1))
x_g_my_hand_c = hand_count_embed(x_global_3[:, 1])
x_g_op_hand_c = hand_count_embed(x_global_3[:, 8])
x_global = jnp.concatenate([
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1)
x_global = layer_norm()(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c)(f_global)
f_global = layer_norm()(f_global)
f_cards = f_cards + jnp.expand_dims(f_global, 1)
x_actions = x_actions.astype(jnp.int32)
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = f_a_cards + fc_layer(c)(layer_norm()(f_a_cards))
x_a_feats = action_encoder(x_actions[..., 2:])
f_actions = f_a_cards + layer_norm()(x_a_feats)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, f_cards,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_'].astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP(
(c, c), dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
f_h_actions = layer_norm()(x_h_id) + layer_norm()(x_h_a_feats)
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_action_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, f_h_actions,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=h_mask)
f_actions = layer_norm()(f_actions)
f_s_cards_global = f_cards.mean(axis=1)
c_mask = 1 - a_mask[:, :, None].astype(f_actions.dtype)
f_s_actions_ha = (f_actions * c_mask).sum(axis=1) / c_mask.sum(axis=1)
f_state = jnp.concatenate([f_s_cards_global, f_s_actions_ha], axis=-1)
return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_actions, mask):
c = self.channels
num_heads = max(2, c // 128)
f_actions = EncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask)
logits = MLP((c // 4, 1), dtype=self.dtype, param_dtype=self.param_dtype)(f_actions)
logits = logits[..., 0].astype(jnp.float32)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class Critic(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
c = self.channels
x = MLP((c // 2, 1), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
x = x.astype(jnp.float32)
return x
class PPOAgent(nn.Module):
channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
encoder = Encoder(
channels=self.channels,
num_card_layers=self.num_card_layers,
num_action_layers=self.num_action_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
actor = Actor(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
critic = Critic(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
f_actions, f_state, mask, valid = encoder(x)
logits = actor(f_actions, mask)
value = critic(f_state)
return logits, value, valid
import numpy as np
def evaluate(envs, act_fn, params):
num_episodes = envs.num_envs
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
while True:
actions = act_fn(params, obs)
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if not d:
continue
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
import functools
from typing import Callable, Optional, Sequence, Union, Dict, Any
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.linen.dtypes import promote_dtype
Array = Union[jax.Array, Any]
PRNGKey = jax.Array
RNGSequences = Dict[str, PRNGKey]
Dtype = Union[jax.typing.DTypeLike, Any]
Shape = Sequence[int]
PrecisionLike = Union[jax.lax.Precision, str]
default_kernel_init = nn.initializers.lecun_normal()
default_bias_init = nn.initializers.zeros
class RMSNorm(nn.Module):
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
dtype = jnp.promote_types(self.dtype, jnp.float32)
x = jnp.asarray(x, dtype)
x = x * jax.lax.rsqrt(jnp.square(x).mean(-1,
keepdims=True) + self.epsilon)
reduced_feature_shape = (x.shape[-1],)
scale = self.param(
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
)
x = x * scale
return jnp.asarray(x, self.dtype)
def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
"""1D Sinusoidal Position Embedding Initializer.
Args:
max_len: maximum possible length for the input.
min_scale: float: minimum frequency-scale in sine grating.
max_scale: float: maximum frequency-scale in sine grating.
Returns:
output: init function returning `(1, max_len, d_feature)`
"""
def init(key, shape, dtype=np.float32):
"""Sinusoidal init."""
del key, dtype
d_feature = shape[-1]
pe = np.zeros((max_len, d_feature), dtype=np.float32)
position = np.arange(0, max_len)[:, np.newaxis]
scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
div_term = min_scale * \
np.exp(np.arange(0, d_feature // 2) * scale_factor)
pe[:, : d_feature // 2] = np.sin(position * div_term)
pe[:, d_feature // 2: 2 * (d_feature // 2)
] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
return jnp.array(pe)
return init
class PositionalEncoding(nn.Module):
"""Adds (optionally learned) positional embeddings to the inputs.
"""
max_len: int = 512
learned: bool = False
@nn.compact
def __call__(self, inputs):
"""Applies AddPositionEmbs module.
By default this layer uses a fixed sinusoidal embedding table. If a
learned position embedding is desired, pass an initializer to
posemb_init in the configuration.
Args:
inputs: input data.
Returns:
output: `(bs, timesteps, in_dim)`
"""
# inputs.shape is (batch_size, seq_len, emb_dim)
assert inputs.ndim == 3, (
'Number of dimensions should be 3, but it is: %d' % inputs.ndim
)
length = inputs.shape[1]
pos_emb_shape = (1, self.max_len, inputs.shape[-1])
initializer = sinusoidal_init(max_len=self.max_len)
if self.learned:
pos_embedding = self.param(
'pos_embedding', initializer, pos_emb_shape
)
else:
pos_embedding = initializer(
None, pos_emb_shape, None
)
pe = pos_embedding[:, :length, :]
return inputs + pe
def precompute_freqs_cis(
dim: int, end: int, theta=10000.0, dtype=jnp.float32
):
# returns:
# cos, sin: (end, dim)
freqs = 1.0 / \
(theta ** (np.arange(0, dim, 2, dtype=np.float32)[: (dim // 2)] / dim))
t = np.arange(end, dtype=np.float32) # type: ignore
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
freqs = np.concatenate((freqs, freqs), axis=-1)
cos, sin = np.cos(freqs), np.sin(freqs)
return jnp.array(cos, dtype=dtype), jnp.array(sin, dtype=dtype)
# from chatglm2, different from original rope
def precompute_freqs_cis2(
dim: int, end: int, theta: float = 10000.0, dtype=jnp.float32
):
# returns:
# cos, sin: (end, dim)
freqs = 1.0 / \
(theta ** (np.arange(0, dim, 2, dtype=np.float32)[: (dim // 2)] / dim))
t = np.arange(end, dtype=np.float32) # type: ignore
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
cos, sin = np.cos(freqs), np.sin(freqs)
return jnp.array(cos, dtype=dtype), jnp.array(sin, dtype=dtype)
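# Standard RoPE application: q and k are rotated at their positions as
# x * cos + rotate_half(x) * sin; the "index2" variant (chatglm2 style) rotates only the
# first half of the head dims via apply_cos_sin and passes the rest through.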
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id=None):
# inputs:
# x: (batch_size, seq_len, num_heads, head_dim)
# cos, sin: (seq_len, head_dim)
# position_id: (batch_size, seq_len)
# returns:
# x: (batch_size, seq_len, num_heads, head_dim)
if position_id is None:
q_pos = jnp.arange(q.shape[1])[None, :]
k_pos = jnp.arange(k.shape[1])[None, :]
else:
q_pos = position_id
k_pos = position_id
cos_q = jnp.take(cos, q_pos, axis=0)[:, :, None, :]
sin_q = jnp.take(sin, q_pos, axis=0)[:, :, None, :]
q = (q * cos_q) + (rotate_half(q) * sin_q)
cos_k = jnp.take(cos, k_pos, axis=0)[:, :, None, :]
sin_k = jnp.take(sin, k_pos, axis=0)[:, :, None, :]
k = (k * cos_k) + (rotate_half(k) * sin_k)
return q, k
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return jnp.concatenate((-x2, x1), axis=-1)
def apply_rotary_pos_emb_index2(q, k, cos, sin, position_id=None):
# inputs:
# x: (batch_size, seq_len, num_heads, head_dim)
# cos, sin: (seq_len, head_dim)
# position_id: (batch_size, seq_len)
# returns:
# x: (batch_size, seq_len, num_heads, head_dim)
if position_id is None:
q_pos = jnp.arange(q.shape[1])[None, :]
k_pos = jnp.arange(k.shape[1])[None, :]
else:
q_pos = position_id
k_pos = position_id
cos_q = jnp.take(cos, q_pos, axis=0)[:, :, None, :]
sin_q = jnp.take(sin, q_pos, axis=0)[:, :, None, :]
q = apply_cos_sin(q, cos_q, sin_q)
cos_k = jnp.take(cos, k_pos, axis=0)[:, :, None, :]
sin_k = jnp.take(sin, k_pos, axis=0)[:, :, None, :]
k = apply_cos_sin(k, cos_k, sin_k)
return q, k
def apply_cos_sin(x, cos, sin):
dim = x.shape[-1]
x1 = x[..., :dim // 2]
x2 = x[..., dim // 2:]
x1 = x1.reshape(x1.shape[:-1] + (-1, 2))
x1 = jnp.stack((x1[..., 0] * cos - x1[..., 1] * sin,
x1[..., 1] * cos + x1[..., 0] * sin), axis=-1)
x1 = x1.reshape(x2.shape)
x = jnp.concatenate((x1, x2), axis=-1)
return x
def make_apply_rope(head_dim, max_len, dtype, multi_query=False):
if multi_query:
cos, sin = precompute_freqs_cis2(
dim=head_dim // 2, end=max_len, dtype=dtype)
def add_pos(q, k, p=None): return apply_rotary_pos_emb_index2(
q, k, cos, sin, p)
else:
cos, sin = precompute_freqs_cis(
dim=head_dim, end=max_len, dtype=dtype)
def add_pos(q, k, p=None): return apply_rotary_pos_emb_index(
q, k, cos, sin, p)
return add_pos
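# Multi-query / grouped-query attention helper: keys and values are projected with fewer
# heads and repeated here so every query head has a matching key/value head.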
def replicate_for_multi_query(x, num_heads):
src_num_heads, head_dim = x.shape[-2:]
x = jnp.repeat(x, num_heads // src_num_heads, axis=-2)
# x = jnp.expand_dims(x, axis=-2)
# x = jnp.tile(x, (1, 1, 1, num_heads // src_num_heads, 1))
# x = jnp.reshape(x, (*x.shape[:2], num_heads, head_dim))
return x
def dot_product_attention_weights(
query: Array,
key: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.0,
deterministic: bool = False,
dtype: Optional[Dtype] = None,
precision: PrecisionLike = None,
):
"""Computes dot-product attention weights given query and key.
Used by :func:`dot_product_attention`, which is what you'll most likely use.
But if you want access to the attention weights for introspection, then
you can directly call this function and call einsum yourself.
Args:
query: queries for calculating attention with shape of ``[batch..., q_length,
num_heads, qk_depth_per_head]``.
key: keys for calculating attention with shape of ``[batch..., kv_length,
num_heads, qk_depth_per_head]``.
bias: bias for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is ``True``.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: infer from inputs and params)
precision: numerical precision of the computation see ``jax.lax.Precision``
for details.
Returns:
Output of shape ``[batch..., num_heads, q_length, kv_length]``.
"""
query, key = promote_dtype(query, key, dtype=dtype)
dtype = query.dtype
assert query.ndim == key.ndim, 'q, k must have same rank.'
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
# calculate attention matrix
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
# attn weight shape is (batch..., num_heads, q_length, kv_length)
attn_weights = jnp.einsum(
'...qhd,...khd->...hqk', query, key, precision=precision
)
# apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias
# apply attention mask
if mask is not None:
big_neg = jnp.finfo(dtype).min
attn_weights = jnp.where(mask, big_neg, attn_weights)
# normalize the attention weights
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
# apply attention dropout
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
if broadcast_dropout:
# dropout is broadcast across the batch + head dimensions
dropout_shape = tuple([1] * (key.ndim - 2)) + \
attn_weights.shape[-2:]
keep = random.bernoulli(
dropout_rng, keep_prob, dropout_shape) # type: ignore
else:
keep = random.bernoulli(
dropout_rng, keep_prob, attn_weights.shape) # type: ignore
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
attn_weights = attn_weights * multiplier
return attn_weights
def dot_product_attention(
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.0,
deterministic: bool = False,
dtype: Optional[Dtype] = None,
precision: PrecisionLike = None,
):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Note: query, key, value needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of ``[batch..., q_length,
num_heads, qk_depth_per_head]``.
key: keys for calculating attention with shape of ``[batch..., kv_length,
num_heads, qk_depth_per_head]``.
value: values to be used in attention with shape of ``[batch..., kv_length,
num_heads, v_depth_per_head]``.
bias: bias for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is ``True``.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: infer from inputs)
precision: numerical precision of the computation see ``jax.lax.Precision`
for details.
Returns:
Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
"""
query, key, value = promote_dtype(query, key, value, dtype=dtype)
dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert (
query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
), 'q, k, v batch dims must match.'
assert (
query.shape[-2] == key.shape[-2] == value.shape[-2]
), 'q, k, v num_heads must match.'
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
# compute attention weights
attn_weights = dot_product_attention_weights(
query,
key,
bias,
mask,
broadcast_dropout,
dropout_rng,
dropout_rate,
deterministic,
dtype,
precision,
)
# return weighted sum over values for each query position
return jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
)
class MultiheadAttention(nn.Module):
features: int
num_heads: int
max_len: Optional[int] = None
multi_query_groups: Optional[int] = None
dtype: Optional[Dtype] = None
param_dtype: Optional[Dtype] = jnp.float32
broadcast_dropout: bool = False
dropout_rate: float = 0.0
deterministic: Optional[bool] = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_bias_init
qkv_bias: bool = True
out_bias: bool = True
rope: bool = False
@nn.compact
def __call__(
self,
query: Array,
key: Array,
value: Array,
key_padding_mask: Optional[Array] = None,
attn_mask: Optional[Array] = None,
):
r"""
Parameters
----------
query: Array, shape [batch, q_len, features]
Query features.
key: Array, shape [batch, kv_len, features]
Key features.
value: Array, shape [batch, kv_len, features]
Value features.
key_padding_mask: Optional[Array], shape [batch, kv_len]
Mask to indicate which keys have zero padding.
attn_mask: Optional[Array], shape [batch, 1, q_len, kv_len]
Mask to apply to attention scores.
Returns
-------
out: Array, shape [batch, q_len, features]
Output features.
"""
features = self.features
if self.rope:
assert self.max_len is not None, "max_len must be provided for rope"
multi_query = self.multi_query_groups is not None
assert (
features % self.num_heads == 0
), "Memory dimension must be divisible by number of heads."
head_dim = features // self.num_heads
query = nn.DenseGeneral(
features=(self.num_heads, head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.qkv_bias,
axis=-1,
name="query",
)(query)
kv_num_heads = self.num_heads
if multi_query:
kv_num_heads = self.multi_query_groups
kv_dense = [
functools.partial(
nn.DenseGeneral,
features=(kv_num_heads, head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.qkv_bias,
axis=-1,
) for i in range(2)
]
key = kv_dense[0](name="key")(key)
value = kv_dense[1](name="value")(value)
if multi_query:
key = replicate_for_multi_query(key, self.num_heads)
value = replicate_for_multi_query(value, self.num_heads)
if self.rope:
add_pos = make_apply_rope(
head_dim, self.max_len, self.dtype, multi_query)
else:
def add_pos(q, k, p=None): return (q, k)
query, key = add_pos(query, key)
dropout_rng = None
if self.dropout_rate > 0 and not self.deterministic:
dropout_rng = self.make_rng("dropout")
deterministic = False
else:
deterministic = True
if key_padding_mask is not None:
key_padding_mask = key_padding_mask[:, None, None, :]
if attn_mask is not None:
mask = attn_mask
if key_padding_mask is not None:
mask = jnp.logical_or(mask, key_padding_mask)
else:
mask = key_padding_mask
x = dot_product_attention(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=deterministic,
dtype=self.dtype,
)
out = nn.DenseGeneral(
features=features,
axis=(-2, -1),
use_bias=self.out_bias,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dtype=self.dtype,
param_dtype=self.param_dtype,
name="out",
)(x)
return out
class MlpBlock(nn.Module):
intermediate_size: Optional[int] = None
activation: str = "gelu"
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
use_bias: bool = True
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_bias_init
@nn.compact
def __call__(self, inputs):
assert self.activation in [
"gelu", "gelu_new", "relu"], "activation must be gelu, gelu_new or relu"
intermediate_size = self.intermediate_size or 4 * inputs.shape[-1]
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(2)
]
actual_out_dim = inputs.shape[-1]
x = dense[0](
features=intermediate_size,
name="fc_1",
)(inputs)
if self.activation == "gelu":
x = nn.gelu(x, approximate=False)
elif self.activation == "gelu_new":
x = nn.gelu(x, approximate=True)
elif self.activation == "relu":
x = nn.relu(x)
x = dense[1](
features=actual_out_dim,
name="fc_2",
)(x)
return x
class GLUMlpBlock(nn.Module):
intermediate_size: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
use_bias: bool = False
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_bias_init
@nn.compact
def __call__(self, inputs):
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(3)
]
actual_out_dim = inputs.shape[-1]
g = dense[0](
features=self.intermediate_size,
name="gate",
)(inputs)
g = nn.silu(g)
x = g * dense[1](
features=self.intermediate_size,
name="up",
)(inputs)
x = dense[2](
features=actual_out_dim,
name="down",
)(x)
return x
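# Note (added for clarity): GLUMlpBlock is a SwiGLU-style feed-forward,
# computing roughly down(silu(gate(x)) * up(x)); it is the gated MLP used by
# the LLaMA-style layers below.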
class EncoderLayer(nn.Module):
n_heads: int
intermediate_size: Optional[int] = None
activation: str = "relu"
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
layer_norm_epsilon: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention(
features=x.shape[-1],
num_heads=self.n_heads,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + inputs
y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_2")(x)
y = MlpBlock(
intermediate_size=self.intermediate_size,
activation=self.activation,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
name="mlp")(y)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = x + y
return y
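# Usage sketch (illustrative only, with assumed shapes): EncoderLayer is a
# pre-LayerNorm Transformer encoder block; the feature width is taken from the
# input's last dimension.
#
#   layer = EncoderLayer(n_heads=8)
#   x = jnp.zeros((4, 16, 256))
#   variables = layer.init(jax.random.PRNGKey(0), x)
#   y = layer.apply(variables, x)   # same shape as x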
class DecoderLayer(nn.Module):
n_heads: int
intermediate_size: Optional[int] = None
activation: str = "relu"
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
layer_norm_epsilon: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
features = tgt.shape[-1]
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(tgt)
x = MultiheadAttention(
features=features,
num_heads=self.n_heads,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + tgt
y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_2")(x)
y = MultiheadAttention(
features=features,
num_heads=self.n_heads,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="cross_attn")(y, memory, memory, key_padding_mask=memory_key_padding_mask)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = y + x
z = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_3")(y)
z = MlpBlock(
intermediate_size=self.intermediate_size,
activation=self.activation,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
name="mlp")(z)
z = nn.Dropout(rate=self.resid_pdrop)(
z, deterministic=self.deterministic
)
z = y + z
return z
class LlamaEncoderLayer(nn.Module):
n_heads: int
intermediate_size: int
n_positions: int = 512
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
rms_norm_eps: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
x = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention(
features=x.shape[-1],
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + inputs
y = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_2")(x)
y = GLUMlpBlock(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
use_bias=False,
name="mlp")(y)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = x + y
return y
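# Note (added for clarity): LlamaEncoderLayer mirrors EncoderLayer but replaces
# LayerNorm with RMSNorm, enables rotary position embeddings (rope=True, up to
# n_positions), drops the QKV/output biases, and swaps MlpBlock for the gated
# GLUMlpBlock.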
class LlamaDecoderLayer(nn.Module):
n_heads: int
intermediate_size: int
n_positions: int = 512
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
rms_norm_eps: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
features = tgt.shape[-1]
x = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_1")(tgt)
x = MultiheadAttention(
features=features,
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + tgt
y = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_2")(x)
y = MultiheadAttention(
features=features,
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="cross_attn")(y, memory, memory, key_padding_mask=memory_key_padding_mask)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = y + x
z = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_3")(y)
z = GLUMlpBlock(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
use_bias=False,
name="mlp")(z)
z = nn.Dropout(rate=self.resid_pdrop)(
z, deterministic=self.deterministic
)
z = y + z
return z
import jax.numpy as jnp
from ygoai.rl.env import RecordEpisodeStatistics
def masked_mean(x, valid):
x = jnp.where(valid, x, jnp.zeros_like(x))
return x.sum() / valid.sum()
def masked_normalize(x, valid, epsilon=1e-8):
x = jnp.where(valid, x, jnp.zeros_like(x))
n = valid.sum()
mean = x.sum() / n
variance = jnp.square(x - mean).sum() / n
return (x - mean) / jnp.sqrt(variance + epsilon)
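# Usage sketch (illustrative): masked_mean averages only over positions where
# `valid` is nonzero; masked_normalize standardizes x using the same count
# after zeroing invalid entries, e.g. to ignore padded steps in a rollout.
#
#   x = jnp.array([1.0, 2.0, 3.0, 0.0])
#   valid = jnp.array([1, 1, 1, 0])
#   masked_mean(x, valid)       # -> 2.0
#   masked_normalize(x, valid)  # standardized values, invalid entries zeroed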
\ No newline at end of file
...@@ -49,6 +49,9 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
    entropy_loss = masked_mean(entropy, valid)
    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
    optimizer.zero_grad()
    if scaler is None:
        loss.backward()
    else:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
    return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
......
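A sketch of the full mixed-precision update this change enables (added for illustration; the gradient clipping and optimizer-step calls are assumptions about code outside the shown hunk, not copied from the repo):

import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Minimal stand-ins so the pattern runs end to end.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = None  # or GradScaler() when fp16 autocast is enabled
loss = model(torch.randn(8, 4)).pow(2).mean()

optimizer.zero_grad()
if scaler is None:
    # Plain fp32 path: backward, clip, step.
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
else:
    # AMP path: scale the loss, unscale before clipping, then step/update.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()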
...@@ -6,55 +6,7 @@ import pickle
import optree
import torch
from ygoai.rl.env import RecordEpisodeStatistics
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
def split_param_groups(model, regex):
......