Commit 9a2a21e3 authored by sbl1996@126.com

Merge branch 'enhance_history'

parents d8cbf274 e79a43ef
*.pt
*.ptj
*.pkl
# Xmake cache
.xmake/
......
......@@ -88,4 +88,7 @@
## History Actions
- 0,1: card id, uint16 -> 2 uint8
- others: same as legal actions
- 2-12: same as legal actions
- 13: player, discrete, 0: me, 1: oppo
- 14: turn, discrete, truncated to 3
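As a rough illustration of the layout above, a single history-action row could be decoded along these lines (a minimal sketch; the helper name, the dict keys, and the byte order of the card-id split are assumptions, not taken from the environment code):

```python
import numpy as np

def decode_history_action(row: np.ndarray) -> dict:
    # row: one history-action feature vector (one uint8 per field).
    # Indices follow the layout listed above; the high/low byte order of
    # the card id is assumed here and may differ in the actual encoder.
    card_id = (int(row[0]) << 8) | int(row[1])  # 2 uint8 -> uint16
    return {
        "card_id": card_id,
        "action_fields": row[2:13],              # same fields as legal actions
        "player": "me" if row[13] == 0 else "oppo",
        "turn": int(row[14]),                    # truncated to 3
    }
```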
......@@ -14,18 +14,12 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
......@@ -41,7 +35,7 @@ class Args:
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
......@@ -60,37 +54,43 @@ class Args:
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint1: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent"""
checkpoint2: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
def predict_step(agent, obs):
with torch.no_grad():
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
return probs
framework: Optional[Literal["torch", "jax"]] = None
if __name__ == "__main__":
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
if not os.path.exists("replay"):
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
......@@ -101,14 +101,20 @@ if __name__ == "__main__":
random.seed(seed)
np.random.seed(seed)
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
if args.framework is None:
args.framework = "jax" if "flax_model" in args.checkpoint1 else "torch"
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
if args.framework == "torch":
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
else:
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
......@@ -124,36 +130,48 @@ if __name__ == "__main__":
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
if args.checkpoint1.endswith(".ptj"):
agent1 = torch.jit.load(args.checkpoint1)
agent2 = torch.jit.load(args.checkpoint2)
else:
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
predict_step = torch.compile(predict_step, mode='reduce-overhead')
if args.framework == 'torch':
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
if args.checkpoint1.endswith(".ptj"):
agent1 = torch.jit.load(args.checkpoint1)
agent2 = torch.jit.load(args.checkpoint2)
else:
if args.optimize:
# count lines of code_list
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
def get_probs(agent, obs):
with torch.no_grad():
return torch.softmax(agent(obs)[0], dim=-1)
if args.compile:
get_probs = torch.compile(get_probs, mode='reduce-overhead')
elif args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
def optimize_for_inference(agent):
with torch.no_grad():
......@@ -161,9 +179,58 @@ if __name__ == "__main__":
return torch.jit.optimize_for_inference(traced_model)
agent1 = optimize_for_inference(agent1)
agent2 = optimize_for_inference(agent2)
def predict_fn(agent, obs):
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
probs = get_probs(agent, obs)
probs = probs.cpu().numpy()
return probs
predict_fn1 = lambda obs: predict_fn(agent1, obs)
predict_fn2 = lambda obs: predict_fn(agent2, obs)
else:
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOAgent
def create_agent(args):
return PPOAgent(
channels=128,
num_layers=2,
embedding_shape=args.num_embeddings,
)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
print(jax.tree.leaves(params)[0].devices())
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read())
if args.checkpoint1 == args.checkpoint2:
params2 = params1
else:
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(params, obs):
probs = get_probs(params, obs)
return np.array(probs)
predict_fn1 = lambda obs: predict_fn(params1, obs)
predict_fn2 = lambda obs: predict_fn(params2, obs)
obs, infos = envs.reset()
next_to_play_ = infos['to_play']
next_to_play = infos['to_play']
episode_rewards = []
episode_lengths = []
......@@ -174,12 +241,10 @@ if __name__ == "__main__":
start = time.time()
start_step = step
num_envs_half = num_envs // 2
player1_ = np.concatenate([
np.zeros(num_envs_half, dtype=np.int64),
np.ones(num_envs - num_envs_half, dtype=np.int64)
player1 = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
player1 = torch.from_numpy(player1_).to(device=device)
model_time = env_time = 0
while True:
......@@ -189,21 +254,24 @@ if __name__ == "__main__":
model_time = env_time = 0
_start = time.time()
next_to_play = torch.from_numpy(next_to_play_).to(device=device)
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
probs1 = predict_step(agent1, obs).clone()
probs2 = predict_step(agent2, obs).clone()
if args.num_envs != 1:
probs1 = predict_fn1(obs)
probs2 = predict_fn2(obs)
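# per environment, use agent 1's distribution where agent 1 is to act, otherwise agent 2's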
probs = np.where((next_to_play == player1)[:, None], probs1, probs2)
else:
if (next_to_play == player1).all():
probs = predict_fn1(obs)
else:
probs = predict_fn2(obs)
probs = torch.where((next_to_play == player1)[:, None], probs1, probs2)
probs = probs.cpu().numpy()
actions = probs.argmax(axis=1)
model_time += time.time() - _start
to_play = next_to_play_
to_play = next_to_play
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
next_to_play_ = infos['to_play']
next_to_play = infos['to_play']
env_time += time.time() - _start
step += 1
......@@ -211,11 +279,10 @@ if __name__ == "__main__":
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
pl = 1 if to_play[idx] == player1_[idx] else -1
pl = 1 if to_play[idx] == player1[idx] else -1
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
......@@ -223,8 +290,8 @@ if __name__ == "__main__":
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
# Only when num_envs == 1 do we switch the player here
if args.verbose:
player1_ = 1 - player1_
player1 = 1 - player1
if len(episode_lengths) >= args.num_episodes:
......
......@@ -14,18 +14,12 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
......@@ -41,7 +35,7 @@ class Args:
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
......@@ -50,8 +44,6 @@ class Args:
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False
"""whether to play the game"""
selfplay: bool = False
"""whether to use selfplay"""
record: bool = False
"""whether to record the game as YGOPro replays"""
......@@ -67,27 +59,36 @@ class Args:
strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used"""
agent: bool = False
"""whether to use the agent"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load"""
checkpoint: Optional[str] = None
"""the checkpoint to load, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = True
"""if toggled, the model will be optimized"""
convert: bool = False
"""if toggled, the model will be converted to a jit model and the program will exit"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework: Optional[Literal["torch", "jax"]] = None
if __name__ == "__main__":
args = tyro.cli(Args)
......@@ -102,7 +103,6 @@ if __name__ == "__main__":
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
......@@ -113,15 +113,20 @@ if __name__ == "__main__":
random.seed(seed)
np.random.seed(seed)
if args.agent:
if args.checkpoint and args.framework is None:
args.framework = "jax" if "flax_model" in args.checkpoint else "torch"
if args.framework == "torch":
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
else:
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
......@@ -136,15 +141,22 @@ if __name__ == "__main__":
player=args.player,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
play_mode='human' if args.play else ('bot' if args.bot_type == "greedy" else "random"),
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
if args.agent:
if args.checkpoint and args.checkpoint.endswith(".ptj"):
if args.framework == 'torch':
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
if args.checkpoint.endswith(".ptj"):
agent = torch.jit.load(args.checkpoint)
else:
# count lines of code_list
......@@ -154,13 +166,12 @@ if __name__ == "__main__":
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
if args.convert:
......@@ -191,6 +202,48 @@ if __name__ == "__main__":
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
def predict_fn(obs):
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits = agent(obs)[0]
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
return probs
else:
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOAgent
def create_agent(args):
return PPOAgent(
channels=128,
num_layers=2,
embedding_shape=args.num_embeddings,
)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(obs):
probs = get_probs(params, obs)
return np.array(probs)
print(f"loaded checkpoint from {args.checkpoint}")
obs, infos = envs.reset()
next_to_play = infos['to_play']
......@@ -210,16 +263,11 @@ if __name__ == "__main__":
start_step = step
model_time = env_time = 0
if args.agent:
if args.framework:
_start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
probs = predict_fn(obs)
if args.verbose:
print([f"{p:.4f}" for p in probs[probs != 0].tolist()])
print(f"{values[0].item():.4f}")
actions = probs.argmax(axis=1)
model_time += time.time() - _start
else:
......@@ -228,13 +276,6 @@ if __name__ == "__main__":
else:
actions = np.zeros(num_envs, dtype=np.int32)
# for k, v in obs.items():
# v = v[0]
# if k == 'cards_':
# v = np.concatenate([np.arange(v.shape[0])[:, None], v], axis=1)
# print(k, v.tolist())
# print(infos)
# print(actions[0])
to_play = next_to_play
_start = time.time()
......@@ -249,15 +290,7 @@ if __name__ == "__main__":
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
if args.selfplay:
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
if episode_reward < 0:
win = 0
else:
win = 1
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
......
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 100
"""the frequency of saving the model"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 64
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 20
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
num_minibatches: int = 4
"""the number of mini-batches"""
gradient_accumulation_steps: int = 1
"""the number of gradient accumulation steps before performing an optimization step"""
c_clip_min: float = 0.001
"""the minimum value for clipping the V-trace c (trace-cutting) coefficients"""
c_clip_max: float = 1.007
"""the maximum value for clipping the V-trace c (trace-cutting) coefficients"""
rho_clip_min: float = 0.001
"""the minimum value for clipping the V-trace rho (importance-weight) coefficients"""
rho_clip_max: float = 1.007
"""the maximum value for clipping the V-trace rho (importance-weight) coefficients"""
upgo: bool = False
"""whether to use UPGO for policy update"""
ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25
"""the PPO surrogate clipping coefficient"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [1])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logitss: list
rewards: list
learns: list
def create_agent(args):
return PPOAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def apply_fn(
params: flax.core.FrozenDict,
next_obs,
):
logits, value, _valid = create_agent(args).apply(params, next_obs)
return logits, value
def get_action(
params: flax.core.FrozenDict,
next_obs,
):
return apply_fn(params, next_obs)[0].argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs,
key: jax.random.PRNGKey,
):
next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
logits = apply_fn(params, next_obs)[0]
# sample action: Gumbel-max trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
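# Why this works: for u ~ Uniform(0, 1), -log(-log(u)) is a Gumbel(0, 1) sample,
# and argmax(logits + gumbel_noise) is distributed as Categorical(softmax(logits)),
# so the lines below sample from the policy without materializing the softmax.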
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
return next_obs, action, logits, key
# put data in the last index
envs.async_reset()
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
storage = []
ai_player1 = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
next_to_play = None
learn = np.ones(args.local_num_envs, dtype=np.bool_)
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree_map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
num_steps_with_bootstrap = (
args.num_steps + int(len(storage) == 0)
) # num_steps + 1 to get the states for value bootstrapping.
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
for _ in range(0, num_steps_with_bootstrap):
global_step += args.local_num_envs * n_actors * args.world_size
_start = time.time()
next_obs, next_reward, next_done, info = envs.recv()
next_reward = np.where(learn, next_reward, -next_reward)
env_time += time.time() - _start
to_play = next_to_play
next_to_play = info["to_play"]
learn = next_to_play == ai_player1
inference_time_start = time.time()
next_obs, action, logits, key = sample_action(params, next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
envs.send(cpu_action)
storage.append(
Transition(
obs=next_obs,
dones=next_done,
actions=action,
logitss=logits,
rewards=next_reward,
learns=learn,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
partitioned_storage = prepare_data(storage)
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
# move bootstrapping step to the beginning of the next update
storage = storage[-1:]
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns) if len(avg_ep_returns) > 0 else 0
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(
args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(
args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) *
args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(
len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * \
args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (
args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [
str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % (
"\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches)) / args.num_updates
return args.learning_rate * frac
agent = create_agent(args)
params = agent.init(agent_key, sample_obs)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=args.gradient_accumulation_steps,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
agent_state = flax.jax_utils.replicate(
agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict,
obs: np.ndarray,
):
logits, value, valid = create_agent(args).apply(params, obs)
return logits, value.squeeze(-1), valid
def impala_loss(params, obs, actions, logitss, rewards, dones, learns):
# inputs have shape (num_steps + 1, local_num_envs // n_mb)
discounts = (1.0 - dones) * args.gamma
policy_logits, newvalue, valid = jax.vmap(
get_logits_and_value, in_axes=(None, 0))(params, obs)
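# `learns` marks steps taken by the designated learning player (`ai_player1`);
# flipping the value sign on the opponent's steps expresses every value from the
# learning player's perspective in this zero-sum game.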
newvalue = jnp.where(learns, newvalue, -newvalue)
v_t = newvalue[1:]
# Remove the bootstrap timestep from the non-bootstrap tensors.
v_tm1 = newvalue[:-1]
policy_logits = policy_logits[:-1]
logitss = logitss[:-1]
actions = actions[:-1]
mask = 1.0 - dones
rewards = rewards[1:]
discounts = discounts[1:]
mask = mask[:-1]
rhos = rlax.categorical_importance_sampling_ratios(
policy_logits, logitss, actions)
vtrace_fn = partial(
vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
vtrace_returns = jax.vmap(
vtrace_fn, in_axes=1, out_axes=1)(
v_tm1, v_t, rewards, discounts, rhos)
jax.debug.print("R {}", jnp.where(dones[1:-1, :2], rewards[:-1, :2], 0).T)
jax.debug.print("E {}", jnp.where(dones[1:-1, :2], vtrace_returns.errors[:-1, :2] * 100, vtrace_returns.errors[:-1, :2]).T)
jax.debug.print("V {}", v_tm1[:-1, :2].T)
T = v_tm1.shape[0]
if args.upgo:
advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
rewards, v_t, discounts) - v_tm1
else:
advs = vtrace_returns.q_estimate - v_tm1
if args.ppo_clip:
pg_loss = jax.vmap(
partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
rhos, advs, mask) * T
pg_loss = jnp.sum(pg_loss)
else:
pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
pg_loss = jax.vmap(
rlax.policy_gradient_loss, in_axes=1)(
policy_logits, actions, pg_advs, mask) * T
pg_loss = jnp.sum(pg_loss)
baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
ent_loss = jax.vmap(rlax.entropy_loss, in_axes=1)(
policy_logits, mask) * T
ent_loss = jnp.sum(ent_loss)
n_samples = jnp.sum(mask)
pg_loss = pg_loss / n_samples
baseline_loss = baseline_loss / n_samples
ent_loss = ent_loss / n_samples
total_loss = pg_loss
total_loss += args.vf_coef * baseline_loss
total_loss += args.ent_coef * ent_loss
return total_loss, (pg_loss, baseline_loss, ent_loss)
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List[Transition],
key: jax.random.PRNGKey,
):
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
impala_loss_grad_fn = jax.value_and_grad(impala_loss, has_aux=True)
def update_minibatch(agent_state, minibatch):
mb_obs, mb_actions, mb_logitss, mb_rewards, mb_dones, mb_learns = minibatch
(loss, (pg_loss, v_loss, entropy_loss)), grads = impala_loss_grad_fn(
agent_state.params,
mb_obs,
mb_actions,
mb_logitss,
mb_rewards,
mb_dones,
mb_learns,
)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss)
n_mb = args.num_minibatches * args.gradient_accumulation_steps
storage_obs = {
k: jnp.array(jnp.split(v, n_mb, axis=1))
for k, v in storage.obs.items()
}
agent_state, (loss, pg_loss, v_loss, entropy_loss) = jax.lax.scan(
update_minibatch,
agent_state,
(
# (num_steps + 1, local_num_envs) => (n_mb, num_steps + 1, local_num_envs // n_mb)
storage_obs,
jnp.array(jnp.split(storage.actions, n_mb, axis=1)),
jnp.array(jnp.split(storage.logitss, n_mb, axis=1)),
jnp.array(jnp.split(storage.rewards, n_mb, axis=1)),
jnp.array(jnp.split(storage.dones, n_mb, axis=1)),
jnp.array(jnp.split(storage.learns, n_mb, axis=1)),
),
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(
entropy_loss, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(
unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_storages = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
sharded_storage,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_storages.append(sharded_storage)
rollout_queue_get_time.append(
time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, learner_keys) = multi_device_update(
agent_state,
sharded_storages,
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(
unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads +
thread_id].put(device_params)
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time",
np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time",
time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size",
rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size",
params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss",
v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss",
pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy",
entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints/{run_name}"
os.makedirs(ckpt_dir, exist_ok=True)
model_path = ckpt_dir + "/agent.cleanrl_model"
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(
[
vars(args),
unreplicated_params,
]
)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logprobs: list
rewards: list
mains: list
probs: list
def create_agent(args):
return PPOAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, next_obs):
return create_agent(args).apply(params, next_obs)[0]
def get_action(
params: flax.core.FrozenDict, next_obs):
return get_logits(params, next_obs).argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, key: jax.random.PRNGKey):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = get_logits(params, next_obs)
# sample action: Gumbel-max trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
return next_obs, action, logprob, probs, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, action, logprob, probs, key = sample_action(
params, cached_next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
actions=action,
logprobs=logprob,
rewards=next_reward,
mains=main,
probs=probs,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
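# The episode just ended from env `idx`'s point of view. Walk back to the most
# recent transition taken by the other player and mark it terminal with the
# negated final reward, so each player's trajectory ends with its own outcome.
# Hitting an earlier done first (the break below) covers games where the other
# player never got to act, e.g. an OTK.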
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_done, next_main))
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
agent = create_agent(args)
params = agent.init(agent_key, sample_obs)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logprob_entropy_value(
params: flax.core.FrozenDict, obs, actions,
):
logits, value, valid = create_agent(args).apply(params, obs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1)
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(
params, inputs, actions, logprobs, probs, advantages, target_values):
newlogprob, newprobs, entropy, newvalue, valid = \
get_logprob_entropy_value(params, inputs, actions)
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss
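# With spo_kld_max set, the usual PPO ratio clipping is replaced by a KLD-based
# damping: the per-state KL between the behaviour policy (probs) and the current
# policy (newprobs) is clipped at spo_kld_max, d = kld_clip / kld, and the
# advantage term is weighted by d for positive advantages and by (2 - d) for
# negative ones (with d = 1 when the KL is negligible).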
if args.spo_kld_max is not None:
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
):
def reshape_minibatch(x, num_minibatches, multi_step=False):
N = num_minibatches
if multi_step:
x = jnp.reshape(x, (N, -1) + x.shape[2:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_next_obs)
next_done, next_main = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]
]
# reorder storage of individual players
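# `mains` is 1 on steps taken by the seat that acts next (`next_main`). Sorting
# on T + mains * num_steps moves the other seat's steps to the front of the time
# axis (each block kept in time order), and `switch` marks the last step of that
# first block, i.e. where the reordered trajectory changes hands. This groups
# each player's transitions into a contiguous segment before GAE is computed.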
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = (storage.mains == next_main).astype(jnp.int32)
indices = jnp.argsort(T[:, None] + mains * num_steps, axis=0)
switch = T[:, None] == (num_steps - 1 - jnp.sum(mains, axis=0))
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
# split minibatches for recompute values
n_mbs = args.num_minibatches // 8
split_inputs = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs, multi_step=True), storage.obs)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def get_value_minibatch(agent_state, mb_inputs):
values = create_agent(args).apply(
agent_state.params, mb_inputs)[1].squeeze(-1)
return agent_state, values
_, values = jax.lax.scan(
get_value_minibatch, agent_state, split_inputs)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, next_obs)[1].squeeze(-1)
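# The GAE helpers take the `switch` mask so the boundary between the two players'
# reordered segments can be handled during bootstrapping; UPGO is an optional
# upgoing-return variant ("2p0s" presumably stands for two-player zero-sum).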
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda)
advantages = advantages[:args.num_steps]
target_values = target_values[:args.num_steps]
def convert_data(x: jnp.ndarray):
x = x.reshape(-1, *x.shape[2:])
x = jax.random.permutation(subkey, x)
return reshape_minibatch(x, args.num_minibatches)
shuffled_storage, shuffled_advantages, shuffled_target_values = jax.tree.map(
convert_data, (storage, advantages, target_values))
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_storage.obs,
shuffled_storage.actions,
shuffled_storage.logprobs,
shuffled_storage.probs,
shuffled_advantages,
shuffled_target_values,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
*sharded_data,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
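# agent_state.opt_state is an optax.MultiSteps state; index [2] is its inner
# chain state and [1] within that is the inject_hyperparams(adam) state, whose
# `hyperparams` dict holds the current (possibly annealed) learning rate,
# replicated across learner devices (hence the trailing [-1]).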
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.splitext(os.path.basename(__file__))[0]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logprobs: list
rewards: list
mains: list
probs: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs, done):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
main = jnp.array(main)
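# Two recurrent states are kept per environment, one for each player. The state
# of the player to act (selected by `main`) is fed to the agent below, and only
# that player's state is replaced with the new rstate afterwards; the waiting
# player's state is left untouched.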
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs), done)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
return next_obs, rstate1, rstate2, action, logprob, probs, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
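# Each environment is assigned a "main" seat (player 0 for half of them, player 1
# for the other half, shuffled); `main` below marks the steps on which that seat
# is acting, which drives the per-player recurrent states and the sign of the
# logged episode statistics.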
start_step = 0
storage = []
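# prepare_data stacks the per-step Transitions into (num_steps, num_envs, ...)
# arrays and splits them along the env axis into one shard per learner device.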
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
initial_rstate1, initial_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, next_rstate1, next_rstate2, action, logprob, probs, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
actions=action,
logprobs=logprob,
rewards=next_reward,
mains=main,
probs=probs,
)
)
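# When an environment terminates, the final reward is only attached to the step
# of the player who acted last. Walk backwards to the most recent step taken by
# the other player, mark it done and assign it the negated reward so both
# players' trajectories end with a correctly signed terminal transition.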
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.dones[idx]:
# Reached the previous episode boundary without finding the other player's
# step (e.g. an OTK where the opponent never acted), so nothing to backfill
break
if t.mains[idx] != cur_main:
t.dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
# initial_rstate1/2 were captured at the start of the rollout for the main seat
# and the opponent respectively; relabel them so that init_rstate1 is the initial
# state of the player matching next_main and init_rstate2 that of the other player
init_rstate1 = jax.tree.map(
lambda x, y: jnp.where(next_main[:, None], x, y), initial_rstate1, initial_rstate2)
init_rstate2 = jax.tree.map(
lambda x, y: jnp.where(next_main[:, None], y, x), initial_rstate1, initial_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches * args.update_epochs) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logprob_entropy_value(
params: flax.core.FrozenDict, inputs, actions,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1)
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(
params, inputs, actions, logprobs, probs, advantages, target_values):
newlogprob, newprobs, entropy, newvalue, valid = \
get_logprob_entropy_value(params, inputs, actions)
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_rstate: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
):
def reshape_minibatch(x, num_minibatches, multi_step=False):
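# For multi_step (sequence) inputs shaped (num_steps, num_envs, ...), the env
# axis is split into `num_minibatches` groups while each environment keeps its
# full time sequence, which is then flattened time-major, e.g. (128, 256, ...)
# with N=2 becomes (2, 128 * 128, ...). For flat inputs the leading axis is
# simply split into N minibatches.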
N = num_minibatches
if multi_step:
x = jnp.reshape(x, (args.num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs, next_rstate, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_obs, sharded_next_rstate, sharded_init_rstate1, sharded_init_rstate2]
]
next_done, next_main = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]
]
# reorder storage of individual players
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = (storage.mains == next_main).astype(jnp.int32)
indices = jnp.argsort(T[:, None] + mains * num_steps, axis=0)
switch = T[:, None] == (num_steps - 1 - jnp.sum(mains, axis=0))
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
# split minibatches for recompute values
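# Values are recomputed with the current parameters at the start of every update
# epoch before GAE; scanning over n_mbs chunks (with the rollout-initial LSTM
# states attached) presumably keeps the recomputation memory-bounded instead of
# running the whole batch through the recurrent agent at once.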
n_mbs = args.num_minibatches // 4
split_init_rstate = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs),
(init_rstate1, init_rstate2))
split_inputs = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs, multi_step=True),
(storage.obs, storage.dones, switch))
split_inputs = split_init_rstate + split_inputs
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def get_value_minibatch(agent_state, mb_inputs):
values = create_agent(args, multi_step=True).apply(
agent_state.params, mb_inputs)[2].squeeze(-1)
return agent_state, values
_, values = jax.lax.scan(
get_value_minibatch, agent_state, split_inputs)
values = values.reshape((n_mbs, args.num_steps, -1)).transpose(1, 0, 2)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda)
advantages = advantages[:args.num_steps]
target_values = target_values[:args.num_steps]
def convert_data(x: jnp.ndarray, multi_step):
x = jax.random.permutation(subkey, x, axis=1)
return reshape_minibatch(x, args.num_minibatches, multi_step)
shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
shuffled_storage, shuffled_switch, shuffled_advantages, shuffled_target_values = jax.tree.map(
partial(convert_data, multi_step=True), (storage, switch, advantages, target_values))
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_switch,
),
shuffled_storage.actions,
shuffled_storage.logprobs,
shuffled_storage.probs,
shuffled_advantages,
shuffled_target_values,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
*sharded_data,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
......@@ -3,7 +3,8 @@ import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
from typing import Optional
import ygoenv
import numpy as np
......@@ -51,10 +52,8 @@ class Args:
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -73,18 +72,25 @@ class Args:
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
fix_target: bool = False
"""if toggled, the target network will be fixed"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
......@@ -92,17 +98,11 @@ class Args:
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
......@@ -124,7 +124,7 @@ class Args:
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
......@@ -140,8 +140,27 @@ class Args:
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_embeddings: Optional[int] = None
"""the number of embeddings (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
......@@ -155,6 +174,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
......@@ -168,7 +188,7 @@ def main():
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
torchrun_setup('nccl', local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
......@@ -203,43 +223,17 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
envs = make_env(args, args.local_num_envs, local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
......@@ -247,7 +241,7 @@ def main():
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
......@@ -259,6 +253,13 @@ def main():
if args.embedding_file:
agent.freeze_embeddings()
if args.fix_target:
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
else:
agent_t = agent
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
......@@ -274,13 +275,19 @@ def main():
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
if args.fix_target:
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
else:
traced_model_t = traced_model
train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
......@@ -292,6 +299,7 @@ def main():
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
......@@ -308,13 +316,12 @@ def main():
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0
for iteration in range(1, args.num_iterations + 1):
for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
frac = 1.0 - iteration / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
......@@ -332,6 +339,10 @@ def main():
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
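# With a fixed target, the frozen opponent's outputs are used on steps where the
# learning agent is not acting (learn is False), so self-play proceeds against a
# snapshot instead of the latest policy.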
if args.fix_target:
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
......@@ -343,10 +354,6 @@ def main():
action = action.cpu().numpy()
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
......@@ -371,7 +378,7 @@ def main():
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
......@@ -389,11 +396,15 @@ def main():
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = traced_model(next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 1:
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, -value)
if args.fix_target:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
nextvalues2 = torch.where(next_to_play != ai_player1, value_t, -value_t)
else:
nextvalues2 = -nextvalues1
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
......@@ -403,7 +414,7 @@ def main():
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay(
......@@ -420,8 +431,11 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.fix_target:
b_learns = learns[:args.num_steps].reshape(-1)
else:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
......@@ -443,9 +457,6 @@ def main():
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
......@@ -462,7 +473,6 @@ def main():
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
......@@ -489,7 +499,29 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
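# With fix_target, the frozen opponent is only refreshed once the learner is
# clearly ahead of it: rank 0 checks the rolling win-rate/return thresholds and
# the decision is broadcast with an all_reduce so every rank swaps and re-traces
# the target model at the same iteration.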
if args.fix_target:
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
......
......@@ -2,27 +2,30 @@ import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
from queue import Queue
from dataclasses import dataclass, field
from typing import Optional, List
import ygoenv
import numpy as np
import optree
import numpy as np
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, load_embeddings
from ygoai.rl.agent2 import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.eval import evaluate
@dataclass
......@@ -35,6 +38,8 @@ class Args:
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
port: int = 29500
"""the port to use for distributed training"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
......@@ -51,10 +56,8 @@ class Args:
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -67,24 +70,26 @@ class Args:
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
local_num_envs: int = 128
"""the number of parallel game environments per actor"""
num_actor_threads: int = 1
"the number of actor threads to use"
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
gae_lambda: float = 0.98
"""the lambda for the general advantage estimation"""
minibatch_size: int = 256
"""the mini-batch size"""
num_minibatches: int = 4
"the number of mini-batches"
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
......@@ -92,19 +97,24 @@ class Args:
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
actor_device_ids: List[int] = field(default_factory=lambda: [0])
"the device ids that actor workers will use"
learner_device_ids: List[int] = field(default_factory=lambda: [0])
"the device ids that learner workers will use"
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
local_torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2)"""
local_env_threads: Optional[int] = 16
"""the number of threads to use for envpool in each actor"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
......@@ -120,228 +130,102 @@ class Args:
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
num_envs: int = 0
"""the number of parallel game environments"""
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_minibatches: int = 0
"""the number of mini-batches (computed in runtime)"""
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
def actor(
args,
a_rank,
rollout_queues: List[Queue],
param_queue: Queue,
run_name,
device_thread_id,
):
if a_rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
writer = None
torch.set_num_threads(args.local_torch_threads)
torch.set_float32_matmul_precision('high')
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
device = torch.device(f"cuda:{device_thread_id}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0:
envs = make_env(args, args.local_num_envs, args.local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if a_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
envs_per_thread = args.local_num_envs // args.local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
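# masked_mean / masked_normalize reduce only over positions where `valid` is True,
# so padded or invalid action slots do not bias the loss statistics.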
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, lstm_state, mb_dones, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid, _ = agent(mb_obs, lstm_state, mb_dones)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
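# old_approx_kl is the naive estimator E[-log r]; approx_kl is the lower-variance
# k3 estimator E[(r - 1) - log r] from the blog post linked above.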
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
def predict_step(agent: Agent, next_obs, next_lstm_state, next_done):
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid, next_lstm_state = agent(next_obs, next_lstm_state, next_done)
return logits, value, next_lstm_state
logits, value, valid = agent(next_obs)
return logits, value
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs, next_lstm_state, next_done), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
predict_step = torch.compile(predict_step, mode=args.compile)
agent_r = agent
else:
agent_r = agent
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
......@@ -352,44 +236,33 @@ def main():
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, dtype=torch.uint8)
next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
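# ai_player1 marks, for each env, which seat is tracked as "player 1" in the two
# per-player GAE streams; half the envs use seat 0 and half seat 1, shuffled.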
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, dtype=next_to_play.dtype)
next_value1 = 0
next_value2 = 0
for iteration in range(1, args.num_iterations + 1):
initial_lstm_state = (next_lstm_state[0].clone(), next_lstm_state[1].clone())
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0
params_buffer = param_queue.get()[1]
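# The state_dict is transferred once at startup; in later iterations the actor only
# waits for the learner's "Done" message and reloads from the same buffer (the
# parameter tensors are presumably shared across processes by torch.multiprocessing).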
for iteration in range(1, args.num_iterations):
if iteration > 2:
param_queue.get()
agent.load_state_dict(params_buffer)
model_time = 0
env_time = 0
collect_start = time.time()
agent.eval()
for step in range(0, args.num_steps):
global_step += args.num_envs
while step < args.num_steps:
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
......@@ -397,7 +270,7 @@ def main():
learns[step] = learn
_start = time.time()
logits, value, next_lstm_state = predict_step(traced_model, next_obs, next_lstm_state, next_done)
logits, value = predict_step(agent_r, next_obs)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
......@@ -417,10 +290,13 @@ def main():
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start
rewards[step] = to_tensor(reward)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool)
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
global_step += args.num_envs
if not writer:
continue
......@@ -436,7 +312,7 @@ def main():
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
......@@ -446,241 +322,247 @@ def main():
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = traced_model(next_obs, next_lstm_state, next_done)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device)
value = predict_step(agent_r, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(args.num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# delta1 = reward1 + args.gamma * nextvalues1 - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# advantages[t] = lastgaelam1_
# nextvalues1 = values[t]
# lastgaelam1 = lastgaelam1_
# else:
# if dones[t+1]:
# reward2 = rewards[t]
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
#
# reward1 = -rewards[t]
# done_used1 = False
# else:
# if not done_used2:
# reward2 = reward2
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
# else:
# reward2 = rewards[t]
# reward1 = reward1
# delta2 = reward2 + args.gamma * nextvalues2 - values[t]
# lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
# advantages[t] = lastgaelam2_
# nextvalues2 = values[t]
# lastgaelam2 = lastgaelam2_
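# Vectorized form of the commented reference above: two GAE streams are kept, one
# per player, with the opponent's reward sign-flipped at episode boundaries.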
learn1 = learns[t]
learn2 = ~learn1
if t != args.num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + args.gamma * nextvalues1 - values[t]
delta2 = reward2 + args.gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
returns = advantages + values
bootstrap_time = time.time() - _start
step = 0
for iq, rq in enumerate(rollout_queues):
n_e = args.local_num_envs // len(rollout_queues)
start = iq * n_e
end = start + n_e
data = []
d = optree.tree_map(lambda x: x[:, start:end],
(obs, actions, logprobs, rewards, dones, values, learns))
for v in d:
data.append(v)
for v in [next_done, nextvalues1, nextvalues2]:
data.append(v[start:end])
rq.put(data)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if a_rank == 0:
fprint(f"SPS: {SPS}")
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
eval_envs, agent_r, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
# if args.world_size > 1:
# dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return = eval_stats.cpu().numpy()
if a_rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
def learner(
args: Args,
l_rank,
rollout_queue: Queue,
param_queue: Queue,
run_name,
ckpt_dir,
device_thread_id,
):
num_learners = len(args.learner_device_ids)
if len(args.learner_device_ids) > 1:
setup('nccl', l_rank, num_learners, args.port)
local_batch_size = args.local_batch_size // num_learners
local_minibatch_size = args.local_minibatch_size // num_learners
torch.set_num_threads(args.local_torch_threads)
torch.set_float32_matmul_precision('high')
args.seed += l_rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - l_rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{device_thread_id}" if torch.cuda.is_available() and args.cuda else "cpu")
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
from ygoai.rl.ppo import train_step
if args.compile:
train_step = torch.compile(train_step, mode=args.compile)
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
global_step = 0
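# Only the first learner of each group (one group of learners per actor thread)
# pushes to param_queue, so each actor receives exactly one message per sync.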
first_in_group = l_rank % (num_learners // (len(args.actor_device_ids) * args.num_actor_threads)) == 0
if first_in_group:
param_queue.put(("Init", agent.state_dict()))
for iteration in range(1, args.num_iterations):
bootstrap_start = time.time()
_start = time.time()
data = rollout_queue.get()
wait_time = time.time() - _start
obs, actions, logprobs, rewards, dones, values, learns, next_done, nextvalues1, nextvalues2 \
= optree.tree_map(lambda x: x.to(device=device, non_blocking=True), data)
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
bootstrap_time = time.time() - bootstrap_start
_start = time.time()
agent.train()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_dones = dones.reshape(-1)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
b_actions = actions[:args.num_steps].flatten(0, 1)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.learn_opponent:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
b_learns = learns[:args.num_steps].reshape(-1)
# Optimizing the policy and value network
assert args.local_num_envs % args.num_minibatches == 0
envsperbatch = args.local_num_envs // args.num_minibatches # minibatch_size // num_steps
envinds = np.arange(args.local_num_envs)
flatinds = np.arange(args.local_batch_size).reshape(args.num_steps, args.local_num_envs)
b_inds = np.arange(local_batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(envinds)
for start in range(0, args.local_num_envs, envsperbatch):
end = start + envsperbatch
mbenvinds = envinds[start:end]
mb_inds = flatinds[:, mbenvinds].ravel() # be really careful about the index
np.random.shuffle(b_inds)
for start in range(0, local_batch_size, local_minibatch_size):
end = start + local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = train_step(
agent, scaler, mb_obs, (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
b_dones[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(optim_params, args.world_size)
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(optim_params, num_learners)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
train_time = time.time() - _start
global_step += args.num_envs
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
if first_in_group:
param_queue.put(("Done", None))
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
if l_rank == 0:
train_time = time.time() - _start
fprint(f"train_time={train_time:.4f}, bootstrap_time={bootstrap_time:.4f}, wait_time={wait_time:.4f}")
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
fprint(f"global_step={global_step}, value_loss={v_loss.item():.4f}, policy_loss={pg_loss.item():.4f}, entropy_loss={entropy_loss.item():.4f}")
if iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
episode_lengths = []
episode_rewards = []
eval_win_rates = []
e_obs = eval_envs.reset()[0]
e_dones_ = np.zeros(local_eval_num_envs, dtype=np.bool_)
e_next_lstm_state = (
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
)
# writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
# writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
# writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
# writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
# writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
# writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
# writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
# writer.add_scalar("losses/explained_variance", explained_var, global_step)
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_dones = to_tensor(e_dones_, dtype=torch.bool)
e_logits, _, e_next_lstm_state = predict_step(traced_model, e_obs, e_next_lstm_state, e_dones)
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones_, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones_):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= local_eval_episodes:
break
eval_return = np.mean(episode_rewards[:local_eval_episodes])
eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
if __name__ == "__main__":
world_size = int(os.environ.get("WORLD_SIZE", 1))
# Eval with old model
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
args.world_size = 1
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.local_env_threads * args.num_actor_threads * len(args.actor_device_ids)
args.local_torch_threads = args.local_torch_threads or int(os.getenv("OMP_NUM_THREADS", "2"))
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
rollout_queues = []
param_queues = []
actor_processes = []
learner_processes = []
num_actors = len(args.actor_device_ids) * args.num_actor_threads
num_learners = len(args.learner_device_ids)
assert num_learners % num_actors == 0, "num_learners must be divisible by num_actors"
group_size = num_learners // num_actors
for i, device_id in enumerate(args.actor_device_ids):
for j in range(args.num_actor_threads):
a_rank = i * args.num_actor_threads + j
param_queues.append(mp.Queue(maxsize=1))
rollout_queues_ = [mp.Queue(maxsize=1) for _ in range(group_size)]
rollout_queues.extend(rollout_queues_)
p = mp.Process(
target=actor,
args=(args, a_rank, rollout_queues_, param_queues[-1], run_name, device_id),
)
actor_processes.append(p)
p.start()
for i, device_id in enumerate(args.learner_device_ids):
param_queue = param_queues[i // group_size]
rollout_queue = rollout_queues[i]
p = mp.Process(
target=learner,
args=(args, i, rollout_queue, param_queue, run_name, ckpt_dir, device_id),
)
learner_processes.append(p)
p.start()
if __name__ == "__main__":
main()
for p in actor_processes + learner_processes:
p.join()
\ No newline at end of file
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
from typing import Optional
import ygoenv
......@@ -22,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_self
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.eval import evaluate
......@@ -52,10 +51,8 @@ class Args:
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -74,15 +71,21 @@ class Args:
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
self_play_prob: float = 0.6
"""the probability of self play"""
max_lp: int = 6
"""the maximum number of LP to add model to the pool"""
iter_per_lp: int = 1000
"""the number of iterations per learning phase"""
target_sample_iter: int = 10
"""the number of iterations to sample the target model"""
minibatch_size: int = 256
"""the mini-batch size"""
......@@ -90,7 +93,7 @@ class Args:
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
......@@ -98,15 +101,11 @@ class Args:
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
......@@ -128,7 +127,7 @@ class Args:
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
......@@ -144,8 +143,31 @@ class Args:
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_embeddings: Optional[int] = None
"""the number of embeddings (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
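# Incremental mean update: mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n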
def update_running_mean(mean, value, count):
return mean + (value - mean) / count
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
......@@ -159,8 +181,12 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
......@@ -169,7 +195,19 @@ def main():
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
torchrun_setup('nccl', local_rank)
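# sync_var shares a statistic across ranks: 'mean' averages it over all ranks,
# otherwise only rank 0 contributes and the all-reduce sum acts as a broadcast.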
def sync_var(var, dtype=torch.float32, reduce='first'):
ts = torch.tensor(var, dtype=dtype, device=device)
if reduce == 'mean':
if args.world_size > 1:
dist.all_reduce(ts, op=dist.ReduceOp.AVG)
else:
if rank != 0:
ts = torch.zeros_like(ts)
if args.world_size > 1:
dist.all_reduce(ts, op=dist.ReduceOp.SUM)
return ts.cpu().numpy()
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
......@@ -204,43 +242,17 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
envs = make_env(args, args.local_num_envs, local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
......@@ -248,7 +260,7 @@ def main():
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
......@@ -260,24 +272,23 @@ def main():
if args.embedding_file:
agent.freeze_embeddings()
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
logits_t, value_t, valid = agent_t(next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
return logits, value
history = []
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here
......@@ -285,51 +296,79 @@ def main():
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
if args.checkpoint:
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
history.append(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile)
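# sample_target draws the opponents for the next `target_sample_iter` iterations:
# -1 means pure self-play, any other value indexes a frozen model in `history`.
# The result is synced so every rank faces the same sequence of targets.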
def sample_target(history):
ts = []
for i in range(args.target_sample_iter):
if len(history) == 0 or random.random() < args.self_play_prob:
ts.append(-1)
else:
ts.append(random.randint(0, len(history) - 1))
ts.sort(reverse=True)
return sync_var(ts, dtype=torch.int64).tolist()
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = [0]
avg_win_rates = [0]
n_episodes = [0]
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value = 0
ai_player1 = to_tensor(ai_player1_, device)
next_value1 = next_value2 = 0
step = 0
ts = []
lp_count = 0
for iteration in range(1, args.num_iterations + 1):
for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
frac = 1.0 - (iteration % args.iter_per_lp) / args.iter_per_lp
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
if iteration % args.iter_per_lp == 0:
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_value1 = next_value2 = 0
step = 0
ts = []
if len(ts) == 0:
ts = sample_target(history)
t_idx = ts.pop()
selfplay = t_idx == -1
if not selfplay:
traced_model_t = history[t_idx]
model_time = 0
env_time = 0
collect_start = time.time()
for step in range(0, args.num_steps):
while step < args.collect_length:
global_step += args.num_envs
for key in obs:
......@@ -339,7 +378,11 @@ def main():
learns[step] = learn
_start = time.time()
logits, value = predict_step(traced_model, traced_model_t, next_obs, learn)
logits, value = predict_step(traced_model, next_obs)
if not selfplay:
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
......@@ -352,7 +395,8 @@ def main():
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value = torch.where(learn, value, next_value) * next_nonterminal
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
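# Track the latest value estimate separately for each player; these become the
# bootstrap values for whichever player is not to act when the rollout is cut off.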
_start = time.time()
to_play = next_to_play_
......@@ -362,9 +406,7 @@ def main():
env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
if not writer:
continue
step += 1
for idx, d in enumerate(next_done_):
if d:
......@@ -372,12 +414,14 @@ def main():
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if len(history) == 0 or not selfplay:
n_episodes[t_idx] += 1
avg_ep_returns[t_idx] = update_running_mean(avg_ep_returns[t_idx], episode_reward, n_episodes[t_idx])
avg_win_rates[t_idx] = update_running_mean(avg_win_rates[t_idx], win, n_episodes[t_idx])
if random.random() < args.log_p:
if writer and random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
......@@ -390,30 +434,50 @@ def main():
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = traced_model(next_obs)[1].reshape(-1)
value_t = traced_model_t(next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
advantages = bootstrap_value_self(
values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
returns = advantages + values
value = predict_step(traced_model, next_obs)[1].reshape(-1)
if not selfplay:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
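# With collect_length > num_steps, the first `step` frames were gathered under the
# previous policy; refresh their value estimates with the current model before GAE.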
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
bootstrap_time = time.time() - _start
_start = time.time()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if selfplay:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
b_learns = learns[:args.num_steps].reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
......@@ -435,9 +499,13 @@ def main():
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
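# Carry the untrained tail (frames num_steps..collect_length) to the front of the
# buffer so the next iteration can extend it (partial GAE over collect_length).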
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
v[:step] = v[args.num_steps:].clone()
for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start
if local_rank == 0:
......@@ -447,7 +515,6 @@ def main():
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
......@@ -474,30 +541,32 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
if (iteration + 1) % args.iter_per_lp == 0:
lp_count += 1
win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
if len(history) == 0 or np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
history.append(traced_model_t)
lp_count = 0
if rank == 0:
version = len(history)
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
fprint(f"model v{version} added to the pool, win_rates={win_rates}")
else:
if rank == 0:
fprint(f"win_rates={win_rates}, not updating the pool")
avg_ep_returns = [0] * len(history)
avg_win_rates = [0] * len(history)
n_episodes = [0] * len(history)
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
......
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Optional
import ygoenv
import numpy as np
import tyro
import torch
torch.set_num_threads(2)
import torch.optim as optim
import torch.distributed as dist
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import fprint
from ygoai.rl.buffer import create_obs, get_obs_shape
from ygoai.rl.ppo import bootstrap_value_selfplay_np as bootstrap_value_selfplay
from ygoai.rl.eval import evaluate
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 256
"the number of parallel game environments"
local_env_threads: Optional[int] = None
"the number of threads to use for environment"
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
fix_target: bool = False
"""if toggled, the target network will be fixed"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
local_minibatch_size: int = 4096
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 500
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
minibatch_size: int = 0
num_envs: int = 0
batch_size: int = 0
num_iterations: int = 0
world_size: int = 0
num_embeddings: Optional[int] = None
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
def _mp_fn(index, world_size):
rank = index
local_rank = index
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.num_envs = args.local_num_envs * args.world_size
args.local_batch_size = args.local_num_envs * args.num_steps
args.minibatch_size = args.local_minibatch_size * args.world_size
args.batch_size = args.num_envs * args.num_steps
args.num_iterations = args.total_timesteps // args.batch_size
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.env_threads = args.local_env_threads * args.world_size
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.local_batch_size % args.local_minibatch_size == 0, "local_batch_size must be divisible by local_minibatch_size"
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
torch.set_num_threads(2)
# torch.set_float32_matmul_precision('high')
if args.world_size > 1:
dist.init_process_group('xla', init_method='xla://')
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
# if args.torch_deterministic:
# torch.backends.cudnn.deterministic = True
# else:
# torch.backends.cudnn.benchmark = True
device = xm.xla_device()
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.local_num_envs, args.local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // args.local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
if args.fix_target:
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
else:
agent_t = agent
# if args.world_size > 1:
# ddp_agent = DDP(agent, gradient_as_bucket_view=True)
# else:
# ddp_agent = agent
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
logits, value, valid = agent(next_obs)
return logits, value
from ygoai.rl.ppo import train_step_t as train_step
if args.compile:
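# torch_xla dynamo backends: 'openxla_eval' compiles inference-only graphs, while
# 'openxla' also traces the backward pass that train_step needs.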
traced_model = torch.compile(agent, backend='openxla_eval')
traced_model_t = traced_model
train_step = torch.compile(train_step, backend='openxla')
else:
traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup
obs_shape = get_obs_shape(obs_space)
obs = {
key: np.zeros(
(args.collect_length, args.local_num_envs, *_obs_shape), dtype=obs_space[key].dtype)
for key, _obs_shape in obs_shape.items()
}
actions = np.zeros((args.collect_length, args.local_num_envs) + action_shape, dtype=np.int64)
logprobs = np.zeros((args.collect_length, args.local_num_envs), dtype=np.float32)
rewards = np.zeros((args.collect_length, args.local_num_envs), dtype=np.float32)
dones = np.zeros((args.collect_length, args.local_num_envs), dtype=np.bool_)
values = np.zeros((args.collect_length, args.local_num_envs), dtype=np.float32)
learns = np.zeros((args.collect_length, args.local_num_envs), dtype=np.bool_)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs_ = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
ai_player1 = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
next_value1 = next_value2 = 0
step = 0
for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - iteration / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
model_time = 0
env_time = 0
o_time1 = 0
o_time2 = 0
collect_start = time.time()
while step < args.collect_length:
global_step += args.num_envs
_start = time.time()
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
learn = next_to_play == ai_player1
learns[step] = learn
o_time1 += time.time() - _start
_start = time.time()
logits, value = predict_step(traced_model, next_obs_)
if args.fix_target:
                logits_t, value_t = predict_step(traced_model_t, next_obs_)
                learn_ = torch.from_numpy(learn).to(device)
                logits = torch.where(learn_[:, None], logits, logits_t)
                value = torch.where(learn_[:, None], value, value_t)
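            # Sample actions with the Gumbel-max trick:
            # argmax(logits + Gumbel noise) is equivalent to sampling from softmax(logits).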
u = torch.rand_like(logits)
action = torch.argmax(logits - torch.log(-torch.log(u)), dim=1)
logprob = logits.log_softmax(dim=1).gather(-1, action[:, None]).squeeze(-1)
value = value.flatten()
xm.mark_step()
model_time += time.time() - _start
_start = time.time()
logprob = logprob.cpu().numpy()
value = value.cpu().numpy()
action = action.cpu().numpy()
o_time2 += time.time() - _start
_start = time.time()
values[step] = value
actions[step] = action
logprobs[step] = logprob
next_nonterminal = 1 - next_done.astype(np.float32)
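            # Cache the latest value estimate for each side of the self-play match
            # (stream 1: the ai_player1 side, stream 2: the other side); zero both at
            # episode boundaries so GAE bootstrapping uses the right baseline.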
next_value1 = np.where(learn, value, next_value1) * next_nonterminal
next_value2 = np.where(learn, next_value2, value) * next_nonterminal
o_time1 += time.time() - _start
_start = time.time()
to_play = next_to_play
next_obs, reward, next_done, info = envs.step(action)
next_to_play = info["to_play"]
env_time += time.time() - _start
_start = time.time()
rewards[step] = reward
o_time1 += time.time() - _start
_start = time.time()
next_obs_ = to_tensor(next_obs, device, torch.uint8)
o_time2 += time.time() - _start
step += 1
if not writer:
continue
for idx, d in enumerate(next_done):
if d:
pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
# if local_rank == 0:
fprint(f"[Rank {rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}, o_time1={o_time1:.4f}, o_time2={o_time2:.4f}")
step = args.collect_length - args.num_steps
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = predict_step(traced_model, next_obs_)[1].reshape(-1)
if args.fix_target:
                value_t = predict_step(traced_model_t, next_obs_)[1].reshape(-1)
                learn_ = torch.from_numpy(next_to_play == ai_player1).to(device)
                value = torch.where(learn_, value, value_t)
value = value.cpu().numpy()
nextvalues1 = np.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = np.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
bootstrap_time = time.time() - _start
train_start = time.time()
d_time1 = 0
d_time2 = 0
d_time3 = 0
# flatten the batch
b_obs = {
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.fix_target:
b_learns = learns[:args.num_steps].reshape(-1)
else:
b_learns = np.ones_like(b_values, dtype=np.bool_)
_start = time.time()
b_obs = to_tensor(b_obs, device=device, dtype=torch.uint8)
b_actions, b_logprobs, b_advantages, b_values, b_returns, b_learns = [
to_tensor(v, device) for v in [b_actions, b_logprobs, b_advantages, b_values, b_returns, b_learns]
]
d_time1 += time.time() - _start
agent.train()
model_time = 0
# Optimizing the policy and value network
clipfracs = []
b_inds = np.arange(args.local_batch_size)
xm.mark_step()
for epoch in range(args.update_epochs):
_start = time.time()
np.random.shuffle(b_inds)
d_time2 += time.time() - _start
_start = time.time()
b_inds_ = to_tensor(b_inds, device=device)
n_mini_batches = args.local_batch_size // args.local_minibatch_size
b_inds_ = b_inds_.reshape(n_mini_batches, args.local_minibatch_size)
xm.mark_step()
d_time3 += time.time() - _start
for i in range(n_mini_batches):
_start = time.time()
mb_inds = b_inds_[i]
xm.mark_step()
d_time3 += time.time() - _start
_start = time.time()
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages,
b_returns, b_values, b_learns, mb_inds, args)
clipfracs.append(clipfrac)
xm.mark_step()
model_time += time.time() - _start
# mb_obs = {
# k: v[mb_inds] for k, v in b_obs.items()
# }
# mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns = [
# v[mb_inds] for v in [b_actions, b_logprobs, b_advantages, b_returns, b_values, b_learns]]
# xm.mark_step()
# old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
# train_step(ddp_agent_t, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages,
# mb_returns, mb_values, mb_learns, args)
# if rank == 0:
# # For short report that only contains a few key metrics.
# print(met.short_metrics_report())
# # For full report that includes all metrics.
# print(met.metrics_report())
# met.clear_all()
clipfrac = torch.stack(clipfracs).mean().item()
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
                v[:step] = v[args.num_steps:].copy()
            for v in [actions, logprobs, rewards, dones, values, learns]:
                v[:step] = v[args.num_steps:].copy()
train_time = time.time() - train_start
if local_rank == 0:
fprint(f"d_time1={d_time1:.4f}, d_time2={d_time2:.4f}, d_time3={d_time3:.4f}")
fprint(f"train_time={train_time:.4f}, model_time={model_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", clipfrac, global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 5
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if args.fix_target:
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
                if args.compile:
                    traced_model_t = torch.compile(agent_t, backend='openxla_eval')
                else:
                    traced_model_t = agent_t
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
# if args.eval_interval and iteration % args.eval_interval == 0:
# # Eval with rule-based policy
# _start = time.time()
# eval_return = evaluate(
# eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
# eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# # sync the statistics
# if args.world_size > 1:
# dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
# eval_return = eval_stats.cpu().numpy()
# if rank == 0:
# writer.add_scalar("charts/eval_return", eval_return, global_step)
# if local_rank == 0:
# eval_time = time.time() - _start
# fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# Eval with old model
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
if __name__ == "__main__":
world_size = int(os.getenv("WORLD_SIZE", "1"))
if world_size == 1:
_mp_fn(0, 1)
else:
xmp.spawn(_mp_fn, args=(world_size,))
......@@ -44,11 +44,9 @@ class PositionalEncoding(nn.Module):
class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
......@@ -165,11 +163,17 @@ class Encoder(nn.Module):
for i in range(num_action_layers)
])
self.action_history_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_history_action_layers)
for i in range(num_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
......@@ -287,6 +291,7 @@ class Encoder(nn.Module):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
batch_size = x_cards.shape[0]
x_cards_1 = x_cards[:, :, :12].long()
x_cards_2 = x_cards[:, :, 12:].to(torch.float32)
......@@ -294,7 +299,10 @@ class Encoder(nn.Module):
x_id = self.encode_card_id(x_cards_1[:, :, :2])
x_id = self.id_norm(x_id)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 2]))
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask[:, 0] = False
f_loc = self.loc_norm(self.loc_embed(x_loc))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3]))
x_feat1 = self.encode_card_feat1(x_cards_1)
......@@ -306,11 +314,14 @@ class Encoder(nn.Module):
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
for layer in self.card_net:
# f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_na_card = self.na_card_embed.expand(batch_size, -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
# TODO: we can't use it because cudagraph says complex memory
# c_mask = torch.cat([torch.zeros(batch_size, 1, dtype=c_mask.dtype, device=c_mask.device), c_mask], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
......@@ -332,23 +343,27 @@ class Encoder(nn.Module):
mask = x_actions[:, :, 2] == 0 # msg == 0
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
mask[:, 0] = False
# mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.action_history_pe(f_h_actions)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = layer(
f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask[:, 0] = False
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.history_action_pe(f_h_actions)
for layer in self.history_action_net:
f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
for layer in self.action_history_net:
f_actions = layer(
f_actions, f_h_actions, tgt_key_padding_mask=mask, memory_key_padding_mask=h_mask)
f_actions = self.action_norm(f_actions)
......@@ -385,13 +400,12 @@ class Actor(nn.Module):
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False,
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True):
super(PPOAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
channels, num_card_layers, num_action_layers, embedding_shape, bias, affine)
c = channels
self.actor = Actor(c, a_trans)
......
import numpy as np
import gymnasium as gym
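# Tracks per-environment episodic return and length for vectorized envs and
# exposes them through infos["r"] / infos["l"]; works with both the synchronous
# step() API and the asynchronous send()/recv() API.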
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
return self.update_stats_and_infos(*super().step(action))
def update_stats_and_infos(self, *args):
observations, rewards, terminated, truncated, infos = args
dones = np.logical_or(terminated, truncated)
self.episode_returns += infos.get("reward", rewards)
self.episode_lengths += 1
self.returned_episode_returns = np.where(
dones, self.episode_returns, self.returned_episode_returns
)
self.returned_episode_lengths = np.where(
dones, self.episode_lengths, self.returned_episode_lengths
)
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
# env_id = infos["env_id"]
# self.env_id = env_id
# self.episode_returns[env_id] += infos["reward"]
# self.returned_episode_returns[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
# )
# self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
# self.episode_lengths[env_id] += 1
# self.returned_episode_lengths[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
# )
# self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
return (
observations,
rewards,
dones,
infos,
)
def async_reset(self):
self.env.async_reset()
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def recv(self):
return self.update_stats_and_infos(*self.env.recv())
def send(self, action):
return self.env.send(action)
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
\ No newline at end of file
from functools import partial
import jax
import jax.numpy as jnp
from typing import NamedTuple
class VTraceOutput(NamedTuple):
q_estimate: jnp.ndarray
errors: jnp.ndarray
def vtrace(
v_tm1,
v_t,
r_t,
discount_t,
rho_tm1,
lambda_=1.0,
c_clip_min: float = 0.001,
c_clip_max: float = 1.007,
rho_clip_min: float = 0.001,
rho_clip_max: float = 1.007,
stop_target_gradients: bool = True,
):
"""
Args:
v_tm1: values at time t-1.
v_t: values at time t.
r_t: reward at time t.
discount_t: discount at time t.
rho_tm1: importance sampling ratios at time t-1.
lambda_: mixing parameter; a scalar or a vector for timesteps t.
    c_clip_min, c_clip_max: clip range for the trace coefficients c_t.
    rho_clip_min, rho_clip_max: clip range for the importance weights rho_t.
stop_target_gradients: whether or not to apply stop gradient to targets.
"""
# Clip importance sampling ratios.
lambda_ = jnp.ones_like(discount_t) * lambda_
c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# Compute the temporal difference errors.
td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# Work backwards computing the td-errors.
def _body(acc, xs):
td_error, discount, c = xs
acc = td_error + discount * c * acc
return acc, acc
_, errors = jax.lax.scan(
_body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# Return errors, maybe disabling gradient flow through bootstrap targets.
errors = jax.lax.select(
stop_target_gradients,
jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
errors)
targets_tm1 = errors + v_tm1
q_bootstrap = jnp.concatenate([
lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
v_t[-1:],
], axis=0)
q_estimate = r_t + discount_t * q_bootstrap
return VTraceOutput(q_estimate=q_estimate, errors=errors)
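# Illustrative usage sketch (assumptions, not part of the original code): the
# inputs are time-major arrays of shape [T] for one sequence (use jax.vmap over
# the batch axis for [T, B] data) and rho_tm1 = pi(a|s) / mu(a|s) under the
# behaviour policy mu. The corrected value targets are then
#   out = vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1)
#   targets_tm1 = out.errors + v_tm1
# and out.q_estimate gives the Q estimates used for the policy gradient.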
def upgo_return(r_t, v_t, discount_t, stop_target_gradients: bool = True):
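    # UPGO (upgoing policy update, as in AlphaStar): follow the actual return as
    # long as the one-step Q estimate is at least the value estimate, otherwise
    # bootstrap with the value, so returns only propagate along
    # better-than-expected trajectories.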
def _body(acc, xs):
r, v, q, discount = xs
acc = r + discount * jnp.where(q >= v, acc, v)
return acc, acc
# TODO: following alphastar, estimate q_t with one-step target
# It might be better to use network to estimate q_t
q_t = r_t[1:] + discount_t[1:] * v_t[1:] # q[:-1]
_, returns = jax.lax.scan(
_body, q_t[-1], (r_t[:-1], v_t[:-1], q_t, discount_t[:-1]), reverse=True)
    # Following rlax.vtrace_td_error_and_advantage, part of the gradient is kept;
    # experiments show that where the gradient is stopped has no impact on performance.
returns = jax.lax.select(
stop_target_gradients, jax.lax.stop_gradient(returns), returns)
returns = jnp.concatenate([returns, q_t[-1:]], axis=0)
return returns
def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
clipped_objective = jnp.fmin(prob_ratios_t * adv_t, clipped_ratios_t * adv_t)
return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(6, 7))
def compute_gae_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
):
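    # `switch` marks the steps where control passes to the other player in
    # self-play: the bootstrap value is negated there (zero-sum game seen from
    # the mover's side) and the GAE accumulator is reset.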
def body_fn(carry, inp):
boot_value, boot_done, next_value, lastgaelam = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
lastgaelam = jnp.where(switch, 0, lastgaelam)
gamma_ = gamma * (1.0 - next_done)
delta = reward + gamma_ * next_value - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, lastgaelam
_, advantages = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
)
target_values = advantages + values
return advantages, target_values
@partial(jax.jit, static_argnums=(6, 7))
def compute_gae_upgo_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
):
def body_fn(carry, inp):
boot_value, boot_done, next_value, next_q, last_return, lastgaelam = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
next_q = jnp.where(switch, -boot_value * gamma, next_q)
last_return = jnp.where(switch, -boot_value, last_return)
lastgaelam = jnp.where(switch, 0, lastgaelam)
gamma_ = gamma * (1.0 - next_done)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
delta = next_q - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return)
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, next_value, next_value, lastgaelam
_, (advantages, returns) = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
)
return returns - values, advantages + values
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - curvalues
delta2 = reward2 + gamma * nextvalues2 - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, advantages
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
):
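    # Maintains two interleaved GAE streams, one per player, over a shared
    # trajectory: `learns` selects which player acted at each step, and the
    # reward is sign-flipped for the opponent's stream (zero-sum game).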
next_value1 = jnp.where(next_learn, next_value, -next_value)
next_value2 = -next_value1
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
_, advantages = jax.lax.scan(
partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
carry, (dones[1:], values, rewards, learns), reverse=True
)
target_values = advantages + values
return advantages, target_values
def compute_gae_once_upgo(carry, inp, gamma, gae_lambda):
next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
next_value1 = jnp.where(real_done1, 0, next_value1)
last_return1 = jnp.where(real_done1, 0, last_return1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
next_value2 = jnp.where(real_done2, 0, next_value2)
last_return2 = jnp.where(real_done2, 0, last_return2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
last_return1_ = reward1 + gamma * jnp.where(
next_q1 >= next_value1, last_return1, next_value1)
last_return2_ = reward2 + gamma * jnp.where(
next_q2 >= next_value2, last_return2, next_value2)
next_q1_ = reward1 + gamma * next_value1
next_q2_ = reward2 + gamma * next_value2
delta1 = next_q1_ - curvalues
delta2 = next_q2_ - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
returns = jnp.where(learn1, last_return1_, last_return2_)
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
next_value1 = jnp.where(learn1, curvalues, next_value1)
next_value2 = jnp.where(learn2, curvalues, next_value2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
next_q1 = jnp.where(learn1, next_q1_, next_q1)
    next_q2 = jnp.where(learn2, next_q2_, next_q2)
last_return1 = jnp.where(learn1, last_return1_, last_return1)
last_return2 = jnp.where(learn2, last_return2_, last_return2)
carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, (advantages, returns)
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
):
next_value1 = jnp.where(next_learn, next_value, -next_value)
next_value2 = -next_value1
last_return1 = next_q1 = next_value1
last_return2 = next_q2 = next_value2
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
_, (advantages, returns) = jax.lax.scan(
partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda),
carry, (dones[1:], values, rewards, learns), reverse=True
)
return returns - values, advantages + values
from typing import Tuple, Union, Optional
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding
default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
div = 8
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
x_a_phase = embed(4, c // div)(x[:, :, 3])
x_a_cancel = embed(3, c // div)(x[:, :, 4])
x_a_finish = embed(3, c // div // 2)(x[:, :, 5])
x_a_position = embed(9, c // div // 2)(x[:, :, 6])
x_a_option = embed(6, c // div // 2)(x[:, :, 7])
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
return jnp.concatenate([
x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib], axis=-1)
class Encoder(nn.Module):
channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
if self.embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False)
embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
fc_layer = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = MLP((c // 8,), last_lin=False, dtype=jnp.float32, param_dtype=self.param_dtype)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_cards_1 = x_cards[:, :, :12].astype(jnp.int32)
x_cards_2 = x_cards[:, :, 12:].astype(jnp.float32)
x_id = decode_id(x_cards_1[:, :, :2])
x_id = id_embed(x_id)
x_id = MLP(
(c, c // 4), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask = c_mask.at[:, 0].set(False)
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x_cards_1[:, :, 3]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x_cards_1[:, :, 4])
x_position = embed(9, c // 16)(x_cards_1[:, :, 5])
x_overley = embed(2, c // 16)(x_cards_1[:, :, 6])
x_attribute = embed(8, c // 16)(x_cards_1[:, :, 7])
x_race = embed(27, c // 16)(x_cards_1[:, :, 8])
x_level = embed(14, c // 16)(x_cards_1[:, :, 9])
x_counter = embed(16, c // 16)(x_cards_1[:, :, 10])
x_negated = embed(3, c // 16)(x_cards_1[:, :, 11])
x_atk = num_transform(x_cards_2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x_cards_2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:])
x_feat = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_feat = layer_norm()(x_feat)
f_cards = jnp.concatenate([x_id, x_feat], axis=-1)
f_cards = f_cards + f_loc + f_seq
num_heads = max(2, c // 128)
for _ in range(self.num_card_layers):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards], axis=1)
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
f_cards = layer_norm()(f_cards)
x_global_1 = x_global[:, :4].astype(jnp.float32)
x_g_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_2 = x_global[:, 4:8].astype(jnp.int32)
x_g_turn = embed(20, c // 8)(x_global_2[:, 0])
x_g_phase = embed(11, c // 8)(x_global_2[:, 1])
x_g_if_first = embed(2, c // 8)(x_global_2[:, 2])
x_g_is_my_turn = embed(2, c // 8)(x_global_2[:, 3])
x_global_3 = x_global[:, 8:22].astype(jnp.int32)
x_g_cs = count_embed(x_global_3).reshape((batch_size, -1))
x_g_my_hand_c = hand_count_embed(x_global_3[:, 1])
x_g_op_hand_c = hand_count_embed(x_global_3[:, 8])
x_global = jnp.concatenate([
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1)
x_global = layer_norm()(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c)(f_global)
f_global = layer_norm()(f_global)
f_cards = f_cards + jnp.expand_dims(f_global, 1)
x_actions = x_actions.astype(jnp.int32)
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = f_a_cards + fc_layer(c)(layer_norm()(f_a_cards))
x_a_feats = action_encoder(x_actions[..., 2:])
f_actions = f_a_cards + layer_norm()(x_a_feats)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions, f_cards,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_'].astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
f_h_actions = layer_norm()(x_h_id) + layer_norm()(x_h_a_feats)
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_action_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(
f_actions, f_h_actions,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=h_mask)
f_actions = layer_norm()(f_actions)
f_s_cards_global = f_cards.mean(axis=1)
c_mask = 1 - a_mask[:, :, None].astype(f_actions.dtype)
f_s_actions_ha = (f_actions * c_mask).sum(axis=1) / c_mask.sum(axis=1)
f_state = jnp.concatenate([f_s_cards_global, f_s_actions_ha], axis=-1)
return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_actions, mask):
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
num_heads = max(2, c // 128)
f_actions = EncoderLayer(
num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask)
logits = mlp((c // 4, 1), use_bias=True)(f_actions)
logits = logits[..., 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class Critic(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(1.0))
        x = mlp((c // 2, 1), use_bias=True)(f_state)
return x
class PPOAgent(nn.Module):
channels: int = 128
num_card_layers: int = 2
num_action_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
encoder = Encoder(
channels=self.channels,
num_card_layers=self.num_card_layers,
num_action_layers=self.num_action_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
actor = Actor(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
critic = Critic(channels=self.channels, dtype=self.dtype, param_dtype=self.param_dtype)
f_actions, f_state, mask, valid = encoder(x)
logits = actor(f_actions, mask)
value = critic(f_state)
return logits, value, valid
from typing import Tuple, Union, Optional, Sequence
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
div = 8
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
x_a_phase = embed(4, c // div)(x[:, :, 3])
x_a_cancel = embed(3, c // div)(x[:, :, 4])
x_a_finish = embed(3, c // div // 2)(x[:, :, 5])
x_a_position = embed(9, c // div // 2)(x[:, :, 6])
x_a_option = embed(6, c // div // 2)(x[:, :, 7])
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
xs = [x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib]
return xs
class CardEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x_id, x):
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :, :10].astype(jnp.int32)
x2 = x[:, :, 10:].astype(jnp.float32)
x_id = mlp(
(c, c // 4), kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
x_loc = x1[:, :, 0]
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x1[:, :, 1]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x1[:, :, 2])
x_position = embed(9, c // 16)(x1[:, :, 3])
x_overley = embed(2, c // 16)(x1[:, :, 4])
x_attribute = embed(8, c // 16)(x1[:, :, 5])
x_race = embed(27, c // 16)(x1[:, :, 6])
x_level = embed(14, c // 16)(x1[:, :, 7])
x_counter = embed(16, c // 16)(x1[:, :, 8])
x_negated = embed(3, c // 16)(x1[:, :, 9])
x_atk = num_transform(x2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
x_f = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
return f_cards
class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
batch_size = x.shape[0]
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :4].astype(jnp.float32)
x2 = x[:, 4:8].astype(jnp.int32)
x3 = x[:, 8:22].astype(jnp.int32)
x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 0:2]))
x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 2:4]))
x_turn = embed(20, c // 8)(x2[:, 0])
x_phase = embed(11, c // 8)(x2[:, 1])
x_if_first = embed(2, c // 8)(x2[:, 2])
x_is_my_turn = embed(2, c // 8)(x2[:, 3])
x_cs = count_embed(x3).reshape((batch_size, -1))
x_my_hand_c = hand_count_embed(x3[:, 1])
x_op_hand_c = hand_count_embed(x3[:, 8])
x = jnp.concatenate([
x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
x = layer_norm()(x)
return x
class Encoder(nn.Module):
channels: int = 128
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
if self.embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
action_encoder = ActionEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id)
# Cards
f_cards = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
g_card_embed = self.param(
'g_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
num_heads = max(2, c // 128)
for _ in range(self.num_layers):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0]
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
x_global = x_global.astype(self.dtype)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
x_h_a_player = embed(2, c // 2)(x_h_actions[:, :, 13])
x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 14])
x_h_a_feats = jnp.concatenate([
*x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
# Actions
x_actions = x_actions.astype(jnp.int32)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
# State
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
# TODO: LSTM
return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
f_state = mlp((c,), use_bias=True)(f_state)
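        # Score each candidate action by the dot product between the projected
        # state embedding and its action embedding.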
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class Critic(nn.Module):
channels: Sequence[int] = (128, 128, 128)
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
return x
class PPOAgent(nn.Module):
channels: int = 128
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
f_actions, f_state, mask, valid = encoder(x)
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return logits, value, valid
class PPOLSTMAgent(nn.Module):
channels: int = 128
num_layers: int = 2
lstm_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
multi_step: bool = False
@nn.compact
def __call__(self, inputs):
if self.multi_step:
# (num_steps * batch_size, ...)
carry1, carry2, x, done, switch = inputs
batch_size = carry1[0].shape[0]
num_steps = done.shape[0] // batch_size
else:
carry, x = inputs
c = self.channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
f_actions, f_state, mask, valid = encoder(x)
lstm_layer = nn.OptimizedLSTMCell(
self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
if self.multi_step:
def body_fn(cell, carry, x, done, switch):
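                # Reset the recurrent carry at episode boundaries (`done`); on
                # `switch` steps, fall back to the carry captured at the start of
                # the segment (`init_carry`).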
carry, init_carry = carry
carry, y = cell(carry, x)
carry = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), carry)
carry = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_carry, carry)
return (carry, init_carry), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
f_state, done, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch))
carry, f_state = scan(lstm_layer, (carry1, carry2), f_state, done, switch)
f_state = f_state.reshape((-1, f_state.shape[-1]))
else:
carry, f_state = lstm_layer(carry, f_state)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return carry, logits, value, valid
import numpy as np
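# Rolls out `envs` until every environment has finished one episode and returns
# (mean return, mean episode length, win rate). `act_fn(params, obs)` should
# return action ids; for recurrent agents, pass `rnn_state` and an
# `act_fn(params, (rnn_state, obs))` that returns (new_rnn_state, actions).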
def evaluate(envs, act_fn, params, rnn_state=None):
num_episodes = envs.num_envs
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
collected = np.zeros((num_episodes,), dtype=np.bool_)
while True:
if rnn_state is None:
actions = act_fn(params, obs)
else:
rnn_state, actions = act_fn(params, (rnn_state, obs))
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if not d or collected[idx]:
continue
collected[idx] = True
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
from typing import Tuple, Union, Optional
import jax.numpy as jnp
import flax.linen as nn
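# Reconstruct a 16-bit id from its two uint8 bytes (high byte first).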
def decode_id(x):
x = x[..., 0] * 256 + x[..., 1]
return x
def bytes_to_bin(x, points, intervals):
points = points.astype(x.dtype)
intervals = intervals.astype(x.dtype)
x = decode_id(x)
x = jnp.expand_dims(x, -1)
return jnp.clip((x - points + intervals) / intervals, 0, 1)
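# Soft thermometer encoding of scalar stats (LP/ATK/DEF) against a set of bin
# edges: 24 of the 32 bins cover 0-8000 and the rest cover 8000-12000, so the
# value range that matters most in play gets finer resolution.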
def make_bin_params(x_max=12000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = jnp.linspace(0, x_max1, sig_bins + 1, dtype=jnp.float32)[1:]
points2 = jnp.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=jnp.float32)[1:]
points = jnp.concatenate([points1, points2], axis=0)
intervals = jnp.concatenate([points[0:1], points[1:] - points[:-1]], axis=0)
return points, intervals
class MLP(nn.Module):
features: Tuple[int, ...] = (128, 128)
last_lin: bool = True
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
use_bias: bool = False
@nn.compact
def __call__(self, x):
n = len(self.features)
for i, c in enumerate(self.features):
if self.last_lin and i == n - 1:
kernel_init = self.last_kernel_init
else:
kernel_init = self.kernel_init
x = nn.Dense(
c, dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=kernel_init, use_bias=self.use_bias)(x)
if i < n - 1 or not self.last_lin:
x = nn.leaky_relu(x, negative_slope=0.1)
return x
import functools
from typing import Callable, Optional, Sequence, Union, Dict, Any
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.linen.dtypes import promote_dtype
Array = Union[jax.Array, Any]
PRNGKey = jax.Array
RNGSequences = Dict[str, PRNGKey]
Dtype = Union[jax.typing.DTypeLike, Any]
Shape = Sequence[int]
PrecisionLike = Union[jax.lax.Precision, str]
default_kernel_init = nn.initializers.lecun_normal()
default_bias_init = nn.initializers.zeros
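# Root-mean-square layer normalization: rescales by 1/RMS of the features and a
# learned per-feature scale, with no mean centering and no bias.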
class RMSNorm(nn.Module):
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
dtype = jnp.promote_types(self.dtype, jnp.float32)
x = jnp.asarray(x, dtype)
x = x * jax.lax.rsqrt(jnp.square(x).mean(-1,
keepdims=True) + self.epsilon)
reduced_feature_shape = (x.shape[-1],)
scale = self.param(
"scale", nn.initializers.ones, reduced_feature_shape, self.param_dtype
)
x = x * scale
return jnp.asarray(x, self.dtype)
def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
"""1D Sinusoidal Position Embedding Initializer.
Args:
max_len: maximum possible length for the input.
min_scale: float: minimum frequency-scale in sine grating.
max_scale: float: maximum frequency-scale in sine grating.
Returns:
output: init function returning `(1, max_len, d_feature)`
"""
def init(key, shape, dtype=np.float32):
"""Sinusoidal init."""
del key, dtype
d_feature = shape[-1]
pe = np.zeros((max_len, d_feature), dtype=np.float32)
position = np.arange(0, max_len)[:, np.newaxis]
scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
div_term = min_scale * \
np.exp(np.arange(0, d_feature // 2) * scale_factor)
pe[:, : d_feature // 2] = np.sin(position * div_term)
pe[:, d_feature // 2: 2 * (d_feature // 2)
] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
return jnp.array(pe)
return init
class PositionalEncoding(nn.Module):
"""Adds (optionally learned) positional embeddings to the inputs.
"""
max_len: int = 512
learned: bool = False
@nn.compact
def __call__(self, inputs):
"""Applies AddPositionEmbs module.
By default this layer uses a fixed sinusoidal embedding table. If a
    learned position embedding is desired, set `learned=True` so the table is
    created as a trainable parameter instead.
Args:
inputs: input data.
Returns:
output: `(bs, timesteps, in_dim)`
"""
# inputs.shape is (batch_size, seq_len, emb_dim)
assert inputs.ndim == 3, (
'Number of dimensions should be 3, but it is: %d' % inputs.ndim
)
length = inputs.shape[1]
pos_emb_shape = (1, self.max_len, inputs.shape[-1])
initializer = sinusoidal_init(max_len=self.max_len)
if self.learned:
pos_embedding = self.param(
'pos_embedding', initializer, pos_emb_shape
)
else:
pos_embedding = initializer(
None, pos_emb_shape, None
)
pe = pos_embedding[:, :length, :]
return inputs + pe
def precompute_freqs_cis(
dim: int, end: int, theta=10000.0, dtype=jnp.float32
):
# returns:
# cos, sin: (end, dim)
freqs = 1.0 / \
(theta ** (np.arange(0, dim, 2, dtype=np.float32)[: (dim // 2)] / dim))
t = np.arange(end, dtype=np.float32) # type: ignore
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
freqs = np.concatenate((freqs, freqs), axis=-1)
cos, sin = np.cos(freqs), np.sin(freqs)
return jnp.array(cos, dtype=dtype), jnp.array(sin, dtype=dtype)
# from chatglm2, different from original rope
def precompute_freqs_cis2(
dim: int, end: int, theta: float = 10000.0, dtype=jnp.float32
):
# returns:
# cos, sin: (end, dim)
freqs = 1.0 / \
(theta ** (np.arange(0, dim, 2, dtype=np.float32)[: (dim // 2)] / dim))
t = np.arange(end, dtype=np.float32) # type: ignore
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
cos, sin = np.cos(freqs), np.sin(freqs)
return jnp.array(cos, dtype=dtype), jnp.array(sin, dtype=dtype)
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id=None):
# inputs:
# x: (batch_size, seq_len, num_heads, head_dim)
# cos, sin: (seq_len, head_dim)
# position_id: (batch_size, seq_len)
# returns:
# x: (batch_size, seq_len, num_heads, head_dim)
if position_id is None:
q_pos = jnp.arange(q.shape[1])[None, :]
k_pos = jnp.arange(k.shape[1])[None, :]
else:
q_pos = position_id
k_pos = position_id
cos_q = jnp.take(cos, q_pos, axis=0)[:, :, None, :]
sin_q = jnp.take(sin, q_pos, axis=0)[:, :, None, :]
q = (q * cos_q) + (rotate_half(q) * sin_q)
cos_k = jnp.take(cos, k_pos, axis=0)[:, :, None, :]
sin_k = jnp.take(sin, k_pos, axis=0)[:, :, None, :]
k = (k * cos_k) + (rotate_half(k) * sin_k)
return q, k
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return jnp.concatenate((-x2, x1), axis=-1)
def apply_rotary_pos_emb_index2(q, k, cos, sin, position_id=None):
# inputs:
# x: (batch_size, seq_len, num_heads, head_dim)
# cos, sin: (seq_len, head_dim)
# position_id: (batch_size, seq_len)
# returns:
# x: (batch_size, seq_len, num_heads, head_dim)
if position_id is None:
q_pos = jnp.arange(q.shape[1])[None, :]
k_pos = jnp.arange(k.shape[1])[None, :]
else:
q_pos = position_id
k_pos = position_id
cos_q = jnp.take(cos, q_pos, axis=0)[:, :, None, :]
sin_q = jnp.take(sin, q_pos, axis=0)[:, :, None, :]
q = apply_cos_sin(q, cos_q, sin_q)
cos_k = jnp.take(cos, k_pos, axis=0)[:, :, None, :]
sin_k = jnp.take(sin, k_pos, axis=0)[:, :, None, :]
k = apply_cos_sin(k, cos_k, sin_k)
return q, k
def apply_cos_sin(x, cos, sin):
dim = x.shape[-1]
x1 = x[..., :dim // 2]
x2 = x[..., dim // 2:]
x1 = x1.reshape(x1.shape[:-1] + (-1, 2))
x1 = jnp.stack((x1[..., 0] * cos - x1[..., 1] * sin,
x1[..., 1] * cos + x1[..., 0] * sin), axis=-1)
x1 = x1.reshape(x2.shape)
x = jnp.concatenate((x1, x2), axis=-1)
return x
def make_apply_rope(head_dim, max_len, dtype, multi_query=False):
if multi_query:
cos, sin = precompute_freqs_cis2(
dim=head_dim // 2, end=max_len, dtype=dtype)
def add_pos(q, k, p=None): return apply_rotary_pos_emb_index2(
q, k, cos, sin, p)
else:
cos, sin = precompute_freqs_cis(
dim=head_dim, end=max_len, dtype=dtype)
def add_pos(q, k, p=None): return apply_rotary_pos_emb_index(
q, k, cos, sin, p)
return add_pos
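# Illustrative usage sketch (added for clarity, not part of the original file):
# the closure returned by make_apply_rope rotates the query/key tensors instead
# of adding a positional embedding. All sizes below are arbitrary example values.
def _example_make_apply_rope():
    import jax.numpy as jnp
    batch, seq_len, num_heads, head_dim = 2, 6, 4, 16
    q = jnp.ones((batch, seq_len, num_heads, head_dim))
    k = jnp.ones((batch, seq_len, num_heads, head_dim))
    add_pos = make_apply_rope(head_dim, max_len=128, dtype=jnp.float32)
    q_rot, k_rot = add_pos(q, k)  # positions default to 0..seq_len-1
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    return q_rot, k_rot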
def replicate_for_multi_query(x, num_heads):
src_num_heads, head_dim = x.shape[-2:]
x = jnp.repeat(x, num_heads // src_num_heads, axis=-2)
# x = jnp.expand_dims(x, axis=-2)
# x = jnp.tile(x, (1, 1, 1, num_heads // src_num_heads, 1))
# x = jnp.reshape(x, (*x.shape[:2], num_heads, head_dim))
return x
def dot_product_attention_weights(
query: Array,
key: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.0,
deterministic: bool = False,
dtype: Optional[Dtype] = None,
precision: PrecisionLike = None,
):
"""Computes dot-product attention weights given query and key.
Used by :func:`dot_product_attention`, which is what you'll most likely use.
But if you want access to the attention weights for introspection, then
you can directly call this function and call einsum yourself.
Args:
query: queries for calculating attention with shape of ``[batch..., q_length,
num_heads, qk_depth_per_head]``.
key: keys for calculating attention with shape of ``[batch..., kv_length,
num_heads, qk_depth_per_head]``.
bias: bias for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is ``True``.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: infer from inputs and params)
precision: numerical precision of the computation see ``jax.lax.Precision``
for details.
Returns:
Output of shape ``[batch..., num_heads, q_length, kv_length]``.
"""
query, key = promote_dtype(query, key, dtype=dtype)
dtype = query.dtype
assert query.ndim == key.ndim, 'q, k must have same rank.'
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
# calculate attention matrix
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
# attn weight shape is (batch..., num_heads, q_length, kv_length)
attn_weights = jnp.einsum(
'...qhd,...khd->...hqk', query, key, precision=precision
)
# apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias
# apply attention mask
if mask is not None:
big_neg = jnp.finfo(dtype).min
attn_weights = jnp.where(mask, big_neg, attn_weights)
# normalize the attention weights
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
# apply attention dropout
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
if broadcast_dropout:
# dropout is broadcast across the batch + head dimensions
dropout_shape = tuple([1] * (key.ndim - 2)) + \
attn_weights.shape[-2:]
keep = random.bernoulli(
dropout_rng, keep_prob, dropout_shape) # type: ignore
else:
keep = random.bernoulli(
dropout_rng, keep_prob, attn_weights.shape) # type: ignore
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
attn_weights = attn_weights * multiplier
return attn_weights
def dot_product_attention(
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.0,
deterministic: bool = False,
dtype: Optional[Dtype] = None,
precision: PrecisionLike = None,
):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Note: query, key, value needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of ``[batch..., q_length,
num_heads, qk_depth_per_head]``.
key: keys for calculating attention with shape of ``[batch..., kv_length,
num_heads, qk_depth_per_head]``.
value: values to be used in attention with shape of ``[batch..., kv_length,
num_heads, v_depth_per_head]``.
bias: bias for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is ``True``.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: infer from inputs)
precision: numerical precision of the computation see ``jax.lax.Precision``
for details.
Returns:
Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
"""
query, key, value = promote_dtype(query, key, value, dtype=dtype)
dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert (
query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
), 'q, k, v batch dims must match.'
assert (
query.shape[-2] == key.shape[-2] == value.shape[-2]
), 'q, k, v num_heads must match.'
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
# compute attention weights
attn_weights = dot_product_attention_weights(
query,
key,
bias,
mask,
broadcast_dropout,
dropout_rng,
dropout_rate,
deterministic,
dtype,
precision,
)
# return weighted sum over values for each query position
return jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
)
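# Illustrative usage sketch (added for clarity, not part of the original file).
# Note the mask convention documented above: entries that are True are masked
# OUT (they receive a large negative bias before the softmax), the opposite of
# the stock Flax helper. Sizes below are arbitrary example values.
def _example_dot_product_attention():
    import jax.numpy as jnp
    batch, q_len, kv_len, heads, depth = 2, 3, 5, 4, 8
    q = jnp.ones((batch, q_len, heads, depth))
    k = jnp.ones((batch, kv_len, heads, depth))
    v = jnp.ones((batch, kv_len, heads, depth))
    # Mask out the last two key positions for every query.
    mask = jnp.zeros((batch, 1, q_len, kv_len), dtype=bool).at[..., -2:].set(True)
    out = dot_product_attention(q, k, v, mask=mask, deterministic=True)
    assert out.shape == (batch, q_len, heads, depth)
    return out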
class MultiheadAttention(nn.Module):
features: int
num_heads: int
max_len: Optional[int] = None
multi_query_groups: Optional[int] = None
dtype: Optional[Dtype] = None
param_dtype: Optional[Dtype] = jnp.float32
broadcast_dropout: bool = False
dropout_rate: float = 0.0
deterministic: Optional[bool] = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_bias_init
qkv_bias: bool = True
out_bias: bool = True
rope: bool = False
@nn.compact
def __call__(
self,
query: Array,
key: Array,
value: Array,
key_padding_mask: Optional[Array] = None,
attn_mask: Optional[Array] = None,
):
r"""
Parameters
----------
query: Array, shape [batch, q_len, features]
Query features.
key: Array, shape [batch, kv_len, features]
Key features.
value: Array, shape [batch, kv_len, features]
Value features.
key_padding_mask: Optional[Array], shape [batch, kv_len]
Mask indicating which key positions are padding and should be ignored (True = padded).
attn_mask: Optional[Array], shape [batch, 1, q_len, kv_len]
Mask to apply to attention scores.
Returns
-------
out: Array, shape [batch, q_len, features]
Output features.
"""
features = self.features
if self.rope:
assert self.max_len is not None, "max_len must be provided for rope"
multi_query = self.multi_query_groups is not None
assert (
features % self.num_heads == 0
), "Memory dimension must be divisible by number of heads."
head_dim = features // self.num_heads
query = nn.DenseGeneral(
features=(self.num_heads, head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.qkv_bias,
axis=-1,
name="query",
)(query)
kv_num_heads = self.num_heads
if multi_query:
kv_num_heads = self.multi_query_groups
kv_dense = [
functools.partial(
nn.DenseGeneral,
features=(kv_num_heads, head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.qkv_bias,
axis=-1,
) for i in range(2)
]
key = kv_dense[0](name="key")(key)
value = kv_dense[1](name="value")(value)
if multi_query:
key = replicate_for_multi_query(key, self.num_heads)
value = replicate_for_multi_query(value, self.num_heads)
if self.rope:
add_pos = make_apply_rope(
head_dim, self.max_len, self.dtype, multi_query)
else:
def add_pos(q, k, p=None): return (q, k)
query, key = add_pos(query, key)
dropout_rng = None
if self.dropout_rate > 0 and not self.deterministic:
dropout_rng = self.make_rng("dropout")
deterministic = False
else:
deterministic = True
if key_padding_mask is not None:
key_padding_mask = key_padding_mask[:, None, None, :]
if attn_mask is not None:
mask = attn_mask
if key_padding_mask is not None:
mask = jnp.logical_or(mask, key_padding_mask)
else:
mask = key_padding_mask
x = dot_product_attention(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=deterministic,
dtype=self.dtype,
)
out = nn.DenseGeneral(
features=features,
axis=(-2, -1),
use_bias=self.out_bias,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dtype=self.dtype,
param_dtype=self.param_dtype,
name="out",
)(x)
return out
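# Illustrative usage sketch (added for clarity, not part of the original file):
# self-attention with a key padding mask that hides the last two tokens of each
# sequence. All sizes below are arbitrary example values.
def _example_multihead_attention():
    import jax
    import jax.numpy as jnp
    batch, seq_len, features = 2, 8, 64
    x = jnp.ones((batch, seq_len, features))
    pad = jnp.zeros((batch, seq_len), dtype=bool).at[:, -2:].set(True)  # True = padded
    attn = MultiheadAttention(features=features, num_heads=4, deterministic=True)
    variables = attn.init(jax.random.PRNGKey(0), x, x, x, key_padding_mask=pad)
    y = attn.apply(variables, x, x, x, key_padding_mask=pad)
    assert y.shape == (batch, seq_len, features)
    return y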
class MlpBlock(nn.Module):
intermediate_size: Optional[int] = None
activation: str = "gelu"
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
use_bias: bool = True
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_bias_init
@nn.compact
def __call__(self, inputs):
assert self.activation in [
"gelu", "gelu_new", "relu"], "activation must be gelu, gelu_new or relu"
intermediate_size = self.intermediate_size or 4 * inputs.shape[-1]
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(2)
]
actual_out_dim = inputs.shape[-1]
x = dense[0](
features=intermediate_size,
name="fc_1",
)(inputs)
if self.activation == "gelu":
x = nn.gelu(x, approximate=False)
elif self.activation == "gelu_new":
x = nn.gelu(x, approximate=True)
elif self.activation == "relu":
x = nn.relu(x)
x = dense[1](
features=actual_out_dim,
name="fc_2",
)(x)
return x
class GLUMlpBlock(nn.Module):
intermediate_size: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
use_bias: bool = False
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_bias_init
@nn.compact
def __call__(self, inputs):
dense = [
functools.partial(
nn.DenseGeneral,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
) for _ in range(3)
]
actual_out_dim = inputs.shape[-1]
g = dense[0](
features=self.intermediate_size,
name="gate",
)(inputs)
g = nn.silu(g)
x = g * dense[1](
features=self.intermediate_size,
name="up",
)(inputs)
x = dense[2](
features=actual_out_dim,
name="down",
)(x)
return x
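# Illustrative usage sketch (added for clarity, not part of the original file):
# the block above is a SwiGLU-style MLP, computing down(silu(gate(x)) * up(x))
# and preserving the input feature dimension. Sizes are arbitrary examples.
def _example_glu_mlp_block():
    import jax
    import jax.numpy as jnp
    x = jnp.ones((2, 5, 32))
    mlp = GLUMlpBlock(intermediate_size=64)
    variables = mlp.init(jax.random.PRNGKey(0), x)
    y = mlp.apply(variables, x)
    assert y.shape == x.shape
    return y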
class EncoderLayer(nn.Module):
n_heads: int
intermediate_size: Optional[int] = None
activation: str = "relu"
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
layer_norm_epsilon: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
inputs = jnp.asarray(inputs, self.dtype)
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention(
features=x.shape[-1],
num_heads=self.n_heads,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + inputs
y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_2")(x)
y = MlpBlock(
intermediate_size=self.intermediate_size,
activation=self.activation,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
name="mlp")(y)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = x + y
return y
class DecoderLayer(nn.Module):
n_heads: int
intermediate_size: Optional[int] = None
activation: str = "relu"
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
layer_norm_epsilon: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
features = tgt.shape[-1]
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(tgt)
x = MultiheadAttention(
features=features,
num_heads=self.n_heads,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + tgt
y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_2")(x)
y = MultiheadAttention(
features=features,
num_heads=self.n_heads,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="cross_attn")(y, memory, memory, key_padding_mask=memory_key_padding_mask)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = y + x
z = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_3")(y)
z = MlpBlock(
intermediate_size=self.intermediate_size,
activation=self.activation,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
name="mlp")(z)
z = nn.Dropout(rate=self.resid_pdrop)(
z, deterministic=self.deterministic
)
z = y + z
return z
class LlamaEncoderLayer(nn.Module):
n_heads: int
intermediate_size: int
n_positions: int = 512
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
rms_norm_eps: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
x = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention(
features=x.shape[-1],
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + inputs
y = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_2")(x)
y = GLUMlpBlock(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
use_bias=False,
name="mlp")(y)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = x + y
return y
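# Illustrative usage sketch (added for clarity, not part of the original file):
# one Llama-style encoder layer (RMSNorm + RoPE attention + SwiGLU MLP) applied
# to dummy inputs with an all-False padding mask. Sizes are arbitrary examples.
def _example_llama_encoder_layer():
    import jax
    import jax.numpy as jnp
    batch, seq_len, features = 2, 12, 128
    x = jnp.ones((batch, seq_len, features))
    pad = jnp.zeros((batch, seq_len), dtype=bool)  # nothing is padded here
    layer = LlamaEncoderLayer(
        n_heads=4, intermediate_size=256, n_positions=512, dtype=jnp.float32)
    variables = layer.init(jax.random.PRNGKey(0), x, pad)
    y = layer.apply(variables, x, pad)
    assert y.shape == x.shape
    return y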
class LlamaDecoderLayer(nn.Module):
n_heads: int
intermediate_size: int
n_positions: int = 512
dtype: Any = None
param_dtype: Any = jnp.float32
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
rms_norm_eps: float = 1e-6
kernel_init: Callable = default_kernel_init
bias_init: Callable = default_bias_init
deterministic: bool = True
@nn.compact
def __call__(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
features = tgt.shape[-1]
x = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_1")(tgt)
x = MultiheadAttention(
features=features,
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
x = nn.Dropout(rate=self.resid_pdrop)(
x, deterministic=self.deterministic)
x = x + tgt
y = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_2")(x)
y = MultiheadAttention(
features=features,
num_heads=self.n_heads,
max_len=self.n_positions,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
qkv_bias=False,
out_bias=False,
rope=True,
dropout_rate=self.attn_pdrop,
deterministic=self.deterministic,
name="cross_attn")(y, memory, memory, key_padding_mask=memory_key_padding_mask)
y = nn.Dropout(rate=self.resid_pdrop)(
y, deterministic=self.deterministic)
y = y + x
z = RMSNorm(epsilon=self.rms_norm_eps,
dtype=self.dtype, name="ln_3")(y)
z = GLUMlpBlock(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
use_bias=False,
name="mlp")(z)
z = nn.Dropout(rate=self.resid_pdrop)(
z, deterministic=self.deterministic
)
z = y + z
return z
import jax.numpy as jnp
from ygoai.rl.env import RecordEpisodeStatistics
def masked_mean(x, valid):
x = jnp.where(valid, x, jnp.zeros_like(x))
return x.sum() / valid.sum()
def masked_normalize(x, valid, epsilon=1e-8):
x = jnp.where(valid, x, jnp.zeros_like(x))
n = valid.sum()
mean = x.sum() / n
variance = jnp.square(x - mean).sum() / n
return (x - mean) / jnp.sqrt(variance + epsilon)
\ No newline at end of file
import numpy as np
import torch
from torch.distributions import Categorical
from torch.cuda.amp import autocast
import torch_xla.core.xla_model as xm
from ygoai.rl.utils import masked_normalize, masked_mean
def entropy_from_logits(logits):
min_real = torch.finfo(logits.dtype).min
logits = torch.clamp(logits, min=min_real)
p_log_p = logits * torch.softmax(logits, dim=-1)
return -p_log_p.sum(-1)
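# Tiny worked example (added for illustration, not part of the original file).
# The helper expects *normalized* log-probabilities, which is why the training
# steps below subtract logsumexp before calling it. All numbers are made up.
def _example_entropy_from_logits():
    logits = torch.log_softmax(torch.tensor([[0.0, 0.0], [10.0, -10.0]]), dim=-1)
    ent = entropy_from_logits(logits)
    # Uniform over two actions -> log(2); near-deterministic -> close to 0.
    assert torch.allclose(ent[0], torch.log(torch.tensor(2.0)), atol=1e-4)
    assert ent[1] < 1e-3
    return ent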
def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
if not args.learn_opponent:
valid = torch.logical_and(valid, mb_learns)
logits, newvalue, valid = agent(mb_obs)[:3]
logits = logits - logits.logsumexp(dim=-1, keepdim=True)
newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
entropy = entropy_from_logits(logits)
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
......@@ -50,11 +58,130 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
if scaler is None:
loss.backward()
else:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
def train_step_t(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages, b_returns, b_values, b_learns, mb_inds, args):
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns = [
v[mb_inds] for v in [b_actions, b_logprobs, b_advantages, b_returns, b_values, b_learns]]
optimizer.zero_grad(True)
logits, newvalue, valid = agent(mb_obs)
logits = logits - logits.logsumexp(dim=-1, keepdim=True)
newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
entropy = entropy_from_logits(logits)
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
loss.backward()
xm.optimizer_step(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
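# Tiny worked example of the clipped surrogate objective used in the training
# steps above (added for illustration, not part of the original file).
# clip_coef = 0.2 and all numbers are made up.
def _example_clipped_objective():
    clip_coef = 0.2
    advantages = torch.tensor([1.0, 1.0, -1.0])
    ratio = torch.tensor([0.5, 1.5, 1.5])
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2)
    # ratio below 1-clip with A>0: unclipped term wins (-0.5);
    # ratio above 1+clip with A>0: clipped to -1.2;
    # ratio above 1+clip with A<0: unclipped term wins (+1.5).
    assert torch.allclose(pg_loss, torch.tensor([-0.5, -1.2, 1.5]))
    return pg_loss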
# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
# logits, newvalue, valid = agent(mb_obs)
# logits = logits - logits.logsumexp(dim=-1, keepdim=True)
# newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
# entropy = entropy_from_logits(logits)
# valid = torch.logical_and(valid, mb_learns)
# logratio = newlogprob - mb_logprobs
# ratio = logratio.exp()
# with torch.no_grad():
# # calculate approx_kl http://joschu.net/blog/kl-approx.html
# old_approx_kl = (-logratio).mean()
# approx_kl = ((ratio - 1) - logratio).mean()
# clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
# if args.norm_adv:
# mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# # Policy loss
# pg_loss1 = -mb_advantages * ratio
# pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
# pg_loss = torch.max(pg_loss1, pg_loss2)
# pg_loss = masked_mean(pg_loss, valid)
# # Value loss
# newvalue = newvalue.view(-1)
# if args.clip_vloss:
# v_loss_unclipped = (newvalue - mb_returns) ** 2
# v_clipped = mb_values + torch.clamp(
# newvalue - mb_values,
# -args.clip_coef,
# args.clip_coef,
# )
# v_loss_clipped = (v_clipped - mb_returns) ** 2
# v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
# v_loss = 0.5 * v_loss_max
# else:
# v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
# v_loss = masked_mean(v_loss, valid)
# entropy_loss = masked_mean(entropy, valid)
# loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
# loss.backward()
# optimizer.step()
# return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
lastgaelam = 0
for t in reversed(range(num_steps)):
if t == num_steps - 1:
nextnonterminal = 1.0 - next_done
nextvalues = nextvalues
else:
nextnonterminal = 1.0 - dones[t + 1]
nextvalues = values[t + 1]
delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    return advantages
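# Worked example of the GAE recursion above (added for illustration, not part
# of the original file): a made-up 3-step, single-env trajectory with no
# terminal states, gamma=0.9 and lambda=0.8.
def _example_gae_recursion():
    gamma, gae_lambda = 0.9, 0.8
    values = torch.tensor([1.0, 1.0, 1.0])
    rewards = torch.tensor([0.0, 0.0, 1.0])
    nextvalue, lastgaelam = 0.0, 0.0
    advantages = torch.zeros(3)
    for t in reversed(range(3)):
        delta = rewards[t] + gamma * nextvalue - values[t]
        lastgaelam = delta + gamma * gae_lambda * lastgaelam
        advantages[t] = lastgaelam
        nextvalue = values[t]
    # A_2 = 0, A_1 = -0.1, A_0 = -0.1 + 0.72 * (-0.1) = -0.172
    assert torch.allclose(advantages, torch.tensor([-0.172, -0.1, 0.0]), atol=1e-6)
    return advantages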
def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
......@@ -190,4 +317,115 @@ def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextva
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
def bootstrap_value_selfplay_upgo(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# next_values1 = 0
# last_return1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# next_values1 = 0
# last_return1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# last_return1_ = reward1 + args.gamma * (last_return1 if (next_qs1 >= next_values1) else next_values1)
# next_q1_ = reward1 + args.gamma * next_values1
# delta1 = next_q1_ - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# returns[t] = last_return1_
# advantages[t] = lastgaelam1_
# next_values1 = values[t]
# lastgaelam1 = lastgaelam1_
# next_qs1 = next_q1_
# last_return1 = last_return1_
# else:
# Skip because it is symmetric
learn1 = learns[t]
learn2 = ~learn1
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - values[t]
delta2 = reward2 + gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
def bootstrap_value_selfplay_np(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
num_steps = rewards.shape[0]
advantages = np.zeros_like(rewards)
# TODO: optimize this
done_used1 = np.ones_like(next_done, dtype=np.bool_)
done_used2 = np.ones_like(next_done, dtype=np.bool_)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(num_steps)):
learn1 = learns[t]
learn2 = ~learn1
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.astype(np.float32) - 0.5)
reward1 = np.where(next_done, rewards[t] * sp, np.where(learn1 & done_used1, 0, reward1))
reward2 = np.where(next_done, rewards[t] * -sp, np.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = np.where(real_done1, 0, nextvalues1)
lastgaelam1 = np.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = np.where(real_done2, 0, nextvalues2)
lastgaelam2 = np.where(real_done2, 0, lastgaelam2)
done_used1 = np.where(
next_done, learn1, np.where(learn1 & ~done_used1, True, done_used1))
done_used2 = np.where(
next_done, learn2, np.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - values[t]
delta2 = reward2 + gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages[t] = np.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = np.where(learn1, values[t], nextvalues1)
nextvalues2 = np.where(learn2, values[t], nextvalues2)
lastgaelam1 = np.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = np.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
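# Illustrative call sketch (added for clarity, not part of the original file):
# a single shared reward stream, with `learns` marking which of the two players
# acted at each step. All values, including gamma/gae_lambda, are made up; only
# the shapes matter here.
def _example_selfplay_advantages():
    num_steps, num_envs = 4, 2
    values = np.zeros((num_steps, num_envs), dtype=np.float32)
    rewards = np.zeros((num_steps, num_envs), dtype=np.float32)
    rewards[-1] = 1.0  # game-outcome reward emitted at the final step
    dones = np.zeros((num_steps, num_envs), dtype=np.bool_)
    learns = np.tile(np.arange(num_steps)[:, None] % 2 == 0, (1, num_envs))
    nextvalues1 = np.zeros(num_envs, dtype=np.float32)
    nextvalues2 = np.zeros(num_envs, dtype=np.float32)
    next_done = np.ones(num_envs, dtype=np.bool_)  # episodes end after the last step
    adv = bootstrap_value_selfplay_np(
        values, rewards, dones, learns, nextvalues1, nextvalues2,
        next_done, gamma=0.997, gae_lambda=0.95)
    assert adv.shape == (num_steps, num_envs)
    return adv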
\ No newline at end of file
......@@ -6,55 +6,7 @@ import pickle
import optree
import torch
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
from ygoai.rl.env import RecordEpisodeStatistics
def split_param_groups(model, regex):
......@@ -103,7 +55,7 @@ def masked_normalize(x, valid, eps=1e-8):
return (x - mean) / std
def to_tensor(x, device, dtype=torch.float32):
def to_tensor(x, device, dtype=None):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
......
import envpool2
print(envpool2.list_all_envs())
\ No newline at end of file
#ifndef BS_THREAD_POOL_HPP
#define BS_THREAD_POOL_HPP
/**
* @file BS_thread_pool.hpp
* @author Barak Shoshany (baraksh@gmail.com) (https://baraksh.com)
* @version 4.1.0
* @date 2024-03-22
* @copyright Copyright (c) 2024 Barak Shoshany. Licensed under the MIT license. If you found this project useful, please consider starring it on GitHub! If you use this library in software of any kind, please provide a link to the GitHub repository https://github.com/bshoshany/thread-pool in the source code and documentation. If you use this library in published research, please cite it as follows: Barak Shoshany, "A C++17 Thread Pool for High-Performance Scientific Computing", doi:10.1016/j.softx.2024.101687, SoftwareX 26 (2024) 101687, arXiv:2105.00613
*
* @brief BS::thread_pool: a fast, lightweight, and easy-to-use C++17 thread pool library. This header file contains the main thread pool class and some additional classes and definitions. No other files are needed in order to use the thread pool itself.
*/
#ifndef __cpp_exceptions
#define BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
#undef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
#endif
#include <chrono> // std::chrono
#include <condition_variable> // std::condition_variable
#include <cstddef> // std::size_t
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
#include <cstdint> // std::int_least16_t
#endif
#ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
#include <exception> // std::current_exception
#endif
#include <functional> // std::function
#include <future> // std::future, std::future_status, std::promise
#include <memory> // std::make_shared, std::make_unique, std::shared_ptr, std::unique_ptr
#include <mutex> // std::mutex, std::scoped_lock, std::unique_lock
#include <optional> // std::nullopt, std::optional
#include <queue> // std::priority_queue (if priority enabled), std::queue
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
#include <stdexcept> // std::runtime_error
#endif
#include <thread> // std::thread
#include <type_traits> // std::conditional_t, std::decay_t, std::invoke_result_t, std::is_void_v, std::remove_const_t (if priority enabled)
#include <utility> // std::forward, std::move
#include <vector> // std::vector
/**
* @brief A namespace used by Barak Shoshany's projects.
*/
namespace BS {
// Macros indicating the version of the thread pool library.
#define BS_THREAD_POOL_VERSION_MAJOR 4
#define BS_THREAD_POOL_VERSION_MINOR 1
#define BS_THREAD_POOL_VERSION_PATCH 0
class thread_pool;
/**
* @brief A type to represent the size of things.
*/
using size_t = std::size_t;
/**
* @brief A convenient shorthand for the type of `std::thread::hardware_concurrency()`. Should evaluate to unsigned int.
*/
using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>;
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
/**
* @brief A type used to indicate the priority of a task. Defined to be an integer with a width of (at least) 16 bits.
*/
using priority_t = std::int_least16_t;
/**
* @brief A namespace containing some pre-defined priorities for convenience.
*/
namespace pr {
constexpr priority_t highest = 32767;
constexpr priority_t high = 16383;
constexpr priority_t normal = 0;
constexpr priority_t low = -16384;
constexpr priority_t lowest = -32768;
} // namespace pr
// Macros used internally to enable or disable the priority arguments in the relevant functions.
#define BS_THREAD_POOL_PRIORITY_INPUT , const priority_t priority = 0
#define BS_THREAD_POOL_PRIORITY_OUTPUT , priority
#else
#define BS_THREAD_POOL_PRIORITY_INPUT
#define BS_THREAD_POOL_PRIORITY_OUTPUT
#endif
/**
* @brief A namespace used to obtain information about the current thread.
*/
namespace this_thread {
/**
* @brief A type returned by `BS::this_thread::get_index()` which can optionally contain the index of a thread, if that thread belongs to a `BS::thread_pool`. Otherwise, it will contain no value.
*/
using optional_index = std::optional<size_t>;
/**
* @brief A type returned by `BS::this_thread::get_pool()` which can optionally contain the pointer to the pool that owns a thread, if that thread belongs to a `BS::thread_pool`. Otherwise, it will contain no value.
*/
using optional_pool = std::optional<thread_pool*>;
/**
* @brief A helper class to store information about the index of the current thread.
*/
class [[nodiscard]] thread_info_index
{
friend class BS::thread_pool;
public:
/**
* @brief Get the index of the current thread. If this thread belongs to a `BS::thread_pool` object, it will have an index from 0 to `BS::thread_pool::get_thread_count() - 1`. Otherwise, for example if this thread is the main thread or an independent `std::thread`, `std::nullopt` will be returned.
*
* @return An `std::optional` object, optionally containing a thread index. Unless you are 100% sure this thread is in a pool, first use `std::optional::has_value()` to check if it contains a value, and if so, use `std::optional::value()` to obtain that value.
*/
[[nodiscard]] optional_index operator()() const
{
return index;
}
private:
/**
* @brief The index of the current thread.
*/
optional_index index = std::nullopt;
}; // class thread_info_index
/**
* @brief A helper class to store information about the thread pool that owns the current thread.
*/
class [[nodiscard]] thread_info_pool
{
friend class BS::thread_pool;
public:
/**
* @brief Get the pointer to the thread pool that owns the current thread. If this thread belongs to a `BS::thread_pool` object, a pointer to that object will be returned. Otherwise, for example if this thread is the main thread or an independent `std::thread`, `std::nullopt` will be returned.
*
* @return An `std::optional` object, optionally containing a pointer to a thread pool. Unless you are 100% sure this thread is in a pool, first use `std::optional::has_value()` to check if it contains a value, and if so, use `std::optional::value()` to obtain that value.
*/
[[nodiscard]] optional_pool operator()() const
{
return pool;
}
private:
/**
* @brief A pointer to the thread pool that owns the current thread.
*/
optional_pool pool = std::nullopt;
}; // class thread_info_pool
/**
* @brief A `thread_local` object used to obtain information about the index of the current thread.
*/
inline thread_local thread_info_index get_index;
/**
* @brief A `thread_local` object used to obtain information about the thread pool that owns the current thread.
*/
inline thread_local thread_info_pool get_pool;
} // namespace this_thread
/**
* @brief A helper class to facilitate waiting for and/or getting the results of multiple futures at once.
*
* @tparam T The return type of the futures.
*/
template <typename T>
class [[nodiscard]] multi_future : public std::vector<std::future<T>>
{
public:
// Inherit all constructors from the base class `std::vector`.
using std::vector<std::future<T>>::vector;
// The copy constructor and copy assignment operator are deleted. The elements stored in a `multi_future` are futures, which cannot be copied.
multi_future(const multi_future&) = delete;
multi_future& operator=(const multi_future&) = delete;
// The move constructor and move assignment operator are defaulted.
multi_future(multi_future&&) = default;
multi_future& operator=(multi_future&&) = default;
/**
* @brief Get the results from all the futures stored in this `multi_future`, rethrowing any stored exceptions.
*
* @return If the futures return `void`, this function returns `void` as well. Otherwise, it returns a vector containing the results.
*/
[[nodiscard]] std::conditional_t<std::is_void_v<T>, void, std::vector<T>> get()
{
if constexpr (std::is_void_v<T>)
{
for (std::future<T>& future : *this)
future.get();
return;
}
else
{
std::vector<T> results;
results.reserve(this->size());
for (std::future<T>& future : *this)
results.push_back(future.get());
return results;
}
}
/**
* @brief Check how many of the futures stored in this `multi_future` are ready.
*
* @return The number of ready futures.
*/
[[nodiscard]] size_t ready_count() const
{
size_t count = 0;
for (const std::future<T>& future : *this)
{
if (future.wait_for(std::chrono::duration<double>::zero()) == std::future_status::ready)
++count;
}
return count;
}
/**
* @brief Check if all the futures stored in this `multi_future` are valid.
*
* @return `true` if all futures are valid, `false` if at least one of the futures is not valid.
*/
[[nodiscard]] bool valid() const
{
bool is_valid = true;
for (const std::future<T>& future : *this)
is_valid = is_valid && future.valid();
return is_valid;
}
/**
* @brief Wait for all the futures stored in this `multi_future`.
*/
void wait() const
{
for (const std::future<T>& future : *this)
future.wait();
}
/**
* @brief Wait for all the futures stored in this `multi_future`, but stop waiting after the specified duration has passed. This function first waits for the first future for the desired duration. If that future is ready before the duration expires, this function waits for the second future for whatever remains of the duration. It continues similarly until the duration expires.
*
* @tparam R An arithmetic type representing the number of ticks to wait.
* @tparam P An `std::ratio` representing the length of each tick in seconds.
* @param duration The amount of time to wait.
* @return `true` if all futures have been waited for before the duration expired, `false` otherwise.
*/
template <typename R, typename P>
bool wait_for(const std::chrono::duration<R, P>& duration) const
{
const std::chrono::time_point<std::chrono::steady_clock> start_time = std::chrono::steady_clock::now();
for (const std::future<T>& future : *this)
{
future.wait_for(duration - (std::chrono::steady_clock::now() - start_time));
if (duration < std::chrono::steady_clock::now() - start_time)
return false;
}
return true;
}
/**
* @brief Wait for all the futures stored in this `multi_future`, but stop waiting after the specified time point has been reached. This function first waits for the first future until the desired time point. If that future is ready before the time point is reached, this function waits for the second future until the desired time point. It continues similarly until the time point is reached.
*
* @tparam C The type of the clock used to measure time.
* @tparam D An `std::chrono::duration` type used to indicate the time point.
* @param timeout_time The time point at which to stop waiting.
* @return `true` if all futures have been waited for before the time point was reached, `false` otherwise.
*/
template <typename C, typename D>
bool wait_until(const std::chrono::time_point<C, D>& timeout_time) const
{
for (const std::future<T>& future : *this)
{
future.wait_until(timeout_time);
if (timeout_time < std::chrono::steady_clock::now())
return false;
}
return true;
}
}; // class multi_future
/**
* @brief A fast, lightweight, and easy-to-use C++17 thread pool class.
*/
class [[nodiscard]] thread_pool
{
public:
// ============================
// Constructors and destructors
// ============================
/**
* @brief Construct a new thread pool. The number of threads will be the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads.
*/
thread_pool() : thread_pool(0, [] {}) {}
/**
* @brief Construct a new thread pool with the specified number of threads.
*
* @param num_threads The number of threads to use.
*/
explicit thread_pool(const concurrency_t num_threads) : thread_pool(num_threads, [] {}) {}
/**
* @brief Construct a new thread pool with the specified initialization function.
*
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
explicit thread_pool(const std::function<void()>& init_task) : thread_pool(0, init_task) {}
/**
* @brief Construct a new thread pool with the specified number of threads and initialization function.
*
* @param num_threads The number of threads to use.
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
thread_pool(const concurrency_t num_threads, const std::function<void()>& init_task) : thread_count(determine_thread_count(num_threads)), threads(std::make_unique<std::thread[]>(determine_thread_count(num_threads)))
{
create_threads(init_task);
}
// The copy and move constructors and assignment operators are deleted. The thread pool uses a mutex, which cannot be copied or moved.
thread_pool(const thread_pool&) = delete;
thread_pool(thread_pool&&) = delete;
thread_pool& operator=(const thread_pool&) = delete;
thread_pool& operator=(thread_pool&&) = delete;
/**
* @brief Destruct the thread pool. Waits for all tasks to complete, then destroys all threads. Note that if the pool is paused, then any tasks still in the queue will never be executed.
*/
~thread_pool()
{
wait();
destroy_threads();
}
// =======================
// Public member functions
// =======================
#ifdef BS_THREAD_POOL_ENABLE_NATIVE_HANDLES
/**
* @brief Get a vector containing the underlying implementation-defined thread handles for each of the pool's threads, as obtained by `std::thread::native_handle()`. Only enabled if `BS_THREAD_POOL_ENABLE_NATIVE_HANDLES` is defined.
*
* @return The native thread handles.
*/
[[nodiscard]] std::vector<std::thread::native_handle_type> get_native_handles() const
{
std::vector<std::thread::native_handle_type> native_handles(thread_count);
for (concurrency_t i = 0; i < thread_count; ++i)
{
native_handles[i] = threads[i].native_handle();
}
return native_handles;
}
#endif
/**
* @brief Get the number of tasks currently waiting in the queue to be executed by the threads.
*
* @return The number of queued tasks.
*/
[[nodiscard]] size_t get_tasks_queued() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return tasks.size();
}
/**
* @brief Get the number of tasks currently being executed by the threads.
*
* @return The number of running tasks.
*/
[[nodiscard]] size_t get_tasks_running() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return tasks_running;
}
/**
* @brief Get the total number of unfinished tasks: either still waiting in the queue, or running in a thread. Note that `get_tasks_total() == get_tasks_queued() + get_tasks_running()`.
*
* @return The total number of tasks.
*/
[[nodiscard]] size_t get_tasks_total() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return tasks_running + tasks.size();
}
/**
* @brief Get the number of threads in the pool.
*
* @return The number of threads.
*/
[[nodiscard]] concurrency_t get_thread_count() const
{
return thread_count;
}
/**
* @brief Get a vector containing the unique identifiers for each of the pool's threads, as obtained by `std::thread::get_id()`.
*
* @return The unique thread identifiers.
*/
[[nodiscard]] std::vector<std::thread::id> get_thread_ids() const
{
std::vector<std::thread::id> thread_ids(thread_count);
for (concurrency_t i = 0; i < thread_count; ++i)
{
thread_ids[i] = threads[i].get_id();
}
return thread_ids;
}
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
/**
* @brief Check whether the pool is currently paused. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined.
*
* @return `true` if the pool is paused, `false` if it is not paused.
*/
[[nodiscard]] bool is_paused() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return paused;
}
/**
* @brief Pause the pool. The workers will temporarily stop retrieving new tasks out of the queue, although any tasks already executed will keep running until they are finished. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined.
*/
void pause()
{
const std::scoped_lock tasks_lock(tasks_mutex);
paused = true;
}
#endif
/**
* @brief Purge all the tasks waiting in the queue. Tasks that are currently running will not be affected, but any tasks still waiting in the queue will be discarded, and will never be executed by the threads. Please note that there is no way to restore the purged tasks.
*/
void purge()
{
const std::scoped_lock tasks_lock(tasks_mutex);
while (!tasks.empty())
tasks.pop();
}
/**
* @brief Submit a function with no arguments and no return value into the task queue, with the specified priority. To push a function with arguments, enclose it in a lambda expression. Does not return a future, so the user must use `wait()` or some other method to ensure that the task finishes executing, otherwise bad things will happen.
*
* @tparam F The type of the function.
* @param task The function to push.
* @param priority The priority of the task. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename F>
void detach_task(F&& task BS_THREAD_POOL_PRIORITY_INPUT)
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
tasks.emplace(std::forward<F>(task) BS_THREAD_POOL_PRIORITY_OUTPUT);
}
task_available_cv.notify_one();
}
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The block function takes two arguments, the start and end of the block, so that it is called only once per block, but it is up to the user to make sure the block function correctly deals with all the indices in each block. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the loop finishes executing, otherwise bad things will happen.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted.
* @param block A function that will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. `block(start, end)` should typically involve a loop of the form `for (T i = start; i < end; ++i)`.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename T, typename F>
void detach_blocks(const T first_index, const T index_after_last, F&& block, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
detach_task(
[block = std::forward<F>(block), start = blks.start(blk), end = blks.end(blk)]
{
block(start, end);
} BS_THREAD_POOL_PRIORITY_OUTPUT);
}
}
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The loop function takes one argument, the loop index, so that it is called many times per block. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the loop finishes executing, otherwise bad things will happen.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted.
* @param loop The function to loop through. Will be called once per index, many times per block. Should take exactly one argument: the loop index.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename T, typename F>
void detach_loop(const T first_index, const T index_after_last, F&& loop, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
detach_task(
[loop = std::forward<F>(loop), start = blks.start(blk), end = blks.end(blk)]
{
for (T i = start; i < end; ++i)
loop(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT);
}
}
/**
* @brief Submit a sequence of tasks enumerated by indices to the queue, with the specified priority. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the sequence finishes executing, otherwise bad things will happen.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function used to define the sequence.
* @param first_index The first index in the sequence.
* @param index_after_last The index after the last index in the sequence. The sequence will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted.
* @param sequence The function used to define the sequence. Will be called once per index. Should take exactly one argument, the index.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename T, typename F>
void detach_sequence(const T first_index, const T index_after_last, F&& sequence BS_THREAD_POOL_PRIORITY_INPUT)
{
for (T i = first_index; i < index_after_last; ++i)
detach_task(
[sequence = std::forward<F>(sequence), i]
{
sequence(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT);
}
/**
* @brief Reset the pool with the total number of hardware threads available, as reported by the implementation. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*/
void reset()
{
reset(0, [] {});
}
/**
* @brief Reset the pool with a new number of threads. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param num_threads The number of threads to use.
*/
void reset(const concurrency_t num_threads)
{
reset(num_threads, [] {});
}
/**
* @brief Reset the pool with the total number of hardware threads available, as reported by the implementation, and a new initialization function. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads and initialization function. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
void reset(const std::function<void()>& init_task)
{
reset(0, init_task);
}
/**
* @brief Reset the pool with a new number of threads and a new initialization function. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads and initialization function. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param num_threads The number of threads to use.
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
void reset(const concurrency_t num_threads, const std::function<void()>& init_task)
{
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
std::unique_lock tasks_lock(tasks_mutex);
const bool was_paused = paused;
paused = true;
tasks_lock.unlock();
#endif
wait();
destroy_threads();
thread_count = determine_thread_count(num_threads);
threads = std::make_unique<std::thread[]>(thread_count);
create_threads(init_task);
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
tasks_lock.lock();
paused = was_paused;
#endif
}
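// Illustrative usage sketch (not part of the library): shrink the pool to 2 threads and
// give each new thread an initialization task. The thread count `2` is arbitrary.
//
//     pool.reset(2, [] { std::printf("pool thread started\n"); });
//     // Tasks that were already queued before the reset are executed by the new threads.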
/**
* @brief Submit a function with no arguments into the task queue, with the specified priority. To submit a function with arguments, enclose it in a lambda expression. If the function has a return value, get a future for the eventual returned value. If the function has no return value, get an `std::future<void>` which can be used to wait until the task finishes.
*
* @tparam F The type of the function.
* @tparam R The return type of the function (can be `void`).
* @param task The function to submit.
* @param priority The priority of the task. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A future to be used later to wait for the function to finish executing and/or obtain its returned value if it has one.
*/
template <typename F, typename R = std::invoke_result_t<std::decay_t<F>>>
[[nodiscard]] std::future<R> submit_task(F&& task BS_THREAD_POOL_PRIORITY_INPUT)
{
const std::shared_ptr<std::promise<R>> task_promise = std::make_shared<std::promise<R>>();
detach_task(
[task = std::forward<F>(task), task_promise]
{
#ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
try
{
#endif
if constexpr (std::is_void_v<R>)
{
task();
task_promise->set_value();
}
else
{
task_promise->set_value(task());
}
#ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
}
catch (...)
{
try
{
task_promise->set_exception(std::current_exception());
}
catch (...)
{
}
}
#endif
} BS_THREAD_POOL_PRIORITY_OUTPUT);
return task_promise->get_future();
}
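// Illustrative usage sketch (not part of the library): submit a task that returns a value
// and retrieve it through the future. Assumes a pool instance `BS::thread_pool pool;`.
//
//     std::future<int> fut = pool.submit_task([] { return 6 * 7; });
//     const int answer = fut.get(); // blocks until the task finishes; answer == 42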
/**
 * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The block function takes two arguments, the start and end of the block, so that it is only called once per block, but it is up to the user to make sure the block function correctly deals with all the indices in each block. Returns a `multi_future` that contains the futures for all of the blocks.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @tparam R The return type of the function to loop through (can be `void`).
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted, and an empty `multi_future` will be returned.
* @param block A function that will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. `block(start, end)` should typically involve a loop of the form `for (T i = start; i < end; ++i)`.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A `multi_future` that can be used to wait for all the blocks to finish. If the block function returns a value, the `multi_future` can also be used to obtain the values returned by each block.
*/
template <typename T, typename F, typename R = std::invoke_result_t<std::decay_t<F>, T, T>>
[[nodiscard]] multi_future<R> submit_blocks(const T first_index, const T index_after_last, F&& block, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
multi_future<R> future;
future.reserve(blks.get_num_blocks());
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
future.push_back(submit_task(
[block = std::forward<F>(block), start = blks.start(blk), end = blks.end(blk)]
{
return block(start, end);
} BS_THREAD_POOL_PRIORITY_OUTPUT));
return future;
}
return {};
}
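// Illustrative usage sketch (not part of the library): sum 0..999 by letting each block
// return a partial sum, then reduce the per-block results from the multi_future
// (std::accumulate requires <numeric>).
//
//     BS::multi_future<long long> mf = pool.submit_blocks(0, 1000,
//         [](const int start, const int end)
//         {
//             long long partial = 0;
//             for (int i = start; i < end; ++i)
//                 partial += i;
//             return partial;
//         });
//     std::vector<long long> partials = mf.get(); // one entry per block
//     const long long total = std::accumulate(partials.begin(), partials.end(), 0LL); // 499500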
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The loop function takes one argument, the loop index, so that it is called many times per block. It must have no return value. Returns a `multi_future` that contains the futures for all of the blocks.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted, and an empty `multi_future` will be returned.
* @param loop The function to loop through. Will be called once per index, many times per block. Should take exactly one argument: the loop index. It cannot have a return value.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A `multi_future` that can be used to wait for all the blocks to finish.
*/
template <typename T, typename F>
[[nodiscard]] multi_future<void> submit_loop(const T first_index, const T index_after_last, F&& loop, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
multi_future<void> future;
future.reserve(blks.get_num_blocks());
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
future.push_back(submit_task(
[loop = std::forward<F>(loop), start = blks.start(blk), end = blks.end(blk)]
{
for (T i = start; i < end; ++i)
loop(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT));
return future;
}
return {};
}
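// Illustrative usage sketch (not part of the library): like submit_blocks(), but the
// lambda receives a single index, and the multi_future is only used for synchronization.
//
//     std::vector<int> squares(100);
//     BS::multi_future<void> mf = pool.submit_loop(0, 100, [&squares](const int i) { squares[i] = i * i; });
//     mf.wait(); // all blocks, and therefore all indices, are done after this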
/**
* @brief Submit a sequence of tasks enumerated by indices to the queue, with the specified priority. Returns a `multi_future` that contains the futures for all of the tasks.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function used to define the sequence.
* @tparam R The return type of the function used to define the sequence (can be `void`).
* @param first_index The first index in the sequence.
* @param index_after_last The index after the last index in the sequence. The sequence will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted, and an empty `multi_future` will be returned.
* @param sequence The function used to define the sequence. Will be called once per index. Should take exactly one argument, the index.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A `multi_future` that can be used to wait for all the tasks to finish. If the sequence function returns a value, the `multi_future` can also be used to obtain the values returned by each task.
*/
template <typename T, typename F, typename R = std::invoke_result_t<std::decay_t<F>, T>>
[[nodiscard]] multi_future<R> submit_sequence(const T first_index, const T index_after_last, F&& sequence BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
multi_future<R> future;
future.reserve(static_cast<size_t>(index_after_last - first_index));
for (T i = first_index; i < index_after_last; ++i)
future.push_back(submit_task(
[sequence = std::forward<F>(sequence), i]
{
return sequence(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT));
return future;
}
return {};
}
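// Illustrative usage sketch (not part of the library): one task per index, each returning
// a value that can be collected from the multi_future in index order.
//
//     BS::multi_future<int> mf = pool.submit_sequence(0, 5, [](const int i) { return i * i; });
//     std::vector<int> results = mf.get(); // {0, 1, 4, 9, 16}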
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
/**
* @brief Unpause the pool. The workers will resume retrieving new tasks out of the queue. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined.
*/
void unpause()
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
paused = false;
}
task_available_cv.notify_all();
}
#endif
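// Illustrative usage sketch (not part of the library), assuming the matching pause()
// member is available when BS_THREAD_POOL_ENABLE_PAUSE is defined:
//
//     pool.pause();                       // workers stop taking new tasks from the queue
//     pool.detach_task([] { /* ... */ }); // queued, but not executed yet
//     pool.unpause();                     // workers resume and pick up the queued task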
// Macros used internally to enable or disable pausing in the waiting and worker functions.
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
#define BS_THREAD_POOL_PAUSED_OR_EMPTY (paused || tasks.empty())
#else
#define BS_THREAD_POOL_PAUSED_OR_EMPTY tasks.empty()
#endif
/**
* @brief Wait for tasks to be completed. Normally, this function waits for all tasks, both those that are currently running in the threads and those that are still waiting in the queue. However, if the pool is paused, this function only waits for the currently running tasks (otherwise it would wait forever). Note: To wait for just one specific task, use `submit_task()` instead, and call the `wait()` member function of the generated future.
*
* @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined.
*/
void wait()
{
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
if (this_thread::get_pool() == this)
throw wait_deadlock();
#endif
std::unique_lock tasks_lock(tasks_mutex);
waiting = true;
tasks_done_cv.wait(tasks_lock,
[this]
{
return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY;
});
waiting = false;
}
/**
* @brief Wait for tasks to be completed, but stop waiting after the specified duration has passed.
*
* @tparam R An arithmetic type representing the number of ticks to wait.
* @tparam P An `std::ratio` representing the length of each tick in seconds.
* @param duration The amount of time to wait.
* @return `true` if all tasks finished running, `false` if the duration expired but some tasks are still running.
*
* @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined.
*/
template <typename R, typename P>
bool wait_for(const std::chrono::duration<R, P>& duration)
{
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
if (this_thread::get_pool() == this)
throw wait_deadlock();
#endif
std::unique_lock tasks_lock(tasks_mutex);
waiting = true;
const bool status = tasks_done_cv.wait_for(tasks_lock, duration,
[this]
{
return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY;
});
waiting = false;
return status;
}
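// Illustrative usage sketch (not part of the library): poll for completion with a timeout
// instead of blocking indefinitely in wait().
//
//     using namespace std::chrono_literals;
//     if (!pool.wait_for(500ms))
//         std::printf("tasks still running after 500 ms\n");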
/**
* @brief Wait for tasks to be completed, but stop waiting after the specified time point has been reached.
*
* @tparam C The type of the clock used to measure time.
* @tparam D An `std::chrono::duration` type used to indicate the time point.
* @param timeout_time The time point at which to stop waiting.
* @return `true` if all tasks finished running, `false` if the time point was reached but some tasks are still running.
*
* @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined.
*/
template <typename C, typename D>
bool wait_until(const std::chrono::time_point<C, D>& timeout_time)
{
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
if (this_thread::get_pool() == this)
throw wait_deadlock();
#endif
std::unique_lock tasks_lock(tasks_mutex);
waiting = true;
const bool status = tasks_done_cv.wait_until(tasks_lock, timeout_time,
[this]
{
return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY;
});
waiting = false;
return status;
}
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
// ==============
// Public classes
// ==============
/**
* @brief An exception that will be thrown by `wait()`, `wait_for()`, and `wait_until()` if the user tries to call them from within a thread of the same pool, which would result in a deadlock.
*/
struct wait_deadlock : public std::runtime_error
{
wait_deadlock() : std::runtime_error("BS::thread_pool::wait_deadlock"){};
};
#endif
private:
// ========================
// Private member functions
// ========================
/**
* @brief Create the threads in the pool and assign a worker to each thread.
*
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks.
*/
void create_threads(const std::function<void()>& init_task)
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
tasks_running = thread_count;
workers_running = true;
}
for (concurrency_t i = 0; i < thread_count; ++i)
{
threads[i] = std::thread(&thread_pool::worker, this, i, init_task);
}
}
/**
* @brief Destroy the threads in the pool.
*/
void destroy_threads()
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
workers_running = false;
}
task_available_cv.notify_all();
for (concurrency_t i = 0; i < thread_count; ++i)
{
threads[i].join();
}
}
/**
* @brief Determine how many threads the pool should have, based on the parameter passed to the constructor or reset().
*
* @param num_threads The parameter passed to the constructor or `reset()`. If the parameter is a positive number, then the pool will be created with this number of threads. If the parameter is non-positive, or a parameter was not supplied (in which case it will have the default value of 0), then the pool will be created with the total number of hardware threads available, as obtained from `std::thread::hardware_concurrency()`. If the latter returns zero for some reason, then the pool will be created with just one thread.
* @return The number of threads to use for constructing the pool.
*/
[[nodiscard]] static concurrency_t determine_thread_count(const concurrency_t num_threads)
{
if (num_threads > 0)
return num_threads;
if (std::thread::hardware_concurrency() > 0)
return std::thread::hardware_concurrency();
return 1;
}
/**
* @brief A worker function to be assigned to each thread in the pool. Waits until it is notified by `detach_task()` that a task is available, and then retrieves the task from the queue and executes it. Once the task finishes, the worker notifies `wait()` in case it is waiting.
*
* @param idx The index of this thread.
* @param init_task An initialization function to run in this thread before it starts to execute any submitted tasks.
*/
void worker(const concurrency_t idx, const std::function<void()>& init_task)
{
this_thread::get_index.index = idx;
this_thread::get_pool.pool = this;
init_task();
std::unique_lock tasks_lock(tasks_mutex);
while (true)
{
--tasks_running;
tasks_lock.unlock();
if (waiting && (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY)
tasks_done_cv.notify_all();
tasks_lock.lock();
task_available_cv.wait(tasks_lock,
[this]
{
return !BS_THREAD_POOL_PAUSED_OR_EMPTY || !workers_running;
});
if (!workers_running)
break;
{
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
const std::function<void()> task = std::move(std::remove_const_t<pr_task&>(tasks.top()).task);
tasks.pop();
#else
const std::function<void()> task = std::move(tasks.front());
tasks.pop();
#endif
++tasks_running;
tasks_lock.unlock();
task();
}
tasks_lock.lock();
}
this_thread::get_index.index = std::nullopt;
this_thread::get_pool.pool = std::nullopt;
}
// ===============
// Private classes
// ===============
/**
* @brief A helper class to divide a range into blocks. Used by `detach_blocks()`, `submit_blocks()`, `detach_loop()`, and `submit_loop()`.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
*/
template <typename T>
class [[nodiscard]] blocks
{
public:
/**
* @brief Construct a `blocks` object with the given specifications.
*
* @param first_index_ The first index in the range.
* @param index_after_last_ The index after the last index in the range.
* @param num_blocks_ The desired number of blocks to divide the range into.
*/
blocks(const T first_index_, const T index_after_last_, const size_t num_blocks_) : first_index(first_index_), index_after_last(index_after_last_), num_blocks(num_blocks_)
{
if (index_after_last > first_index)
{
const size_t total_size = static_cast<size_t>(index_after_last - first_index);
if (num_blocks > total_size)
num_blocks = total_size;
block_size = total_size / num_blocks;
remainder = total_size % num_blocks;
if (block_size == 0)
{
block_size = 1;
num_blocks = (total_size > 1) ? total_size : 1;
}
}
else
{
num_blocks = 0;
}
}
/**
* @brief Get the first index of a block.
*
* @param block The block number.
* @return The first index.
*/
[[nodiscard]] T start(const size_t block) const
{
return first_index + static_cast<T>(block * block_size) + static_cast<T>(block < remainder ? block : remainder);
}
/**
* @brief Get the index after the last index of a block.
*
* @param block The block number.
* @return The index after the last index.
*/
[[nodiscard]] T end(const size_t block) const
{
return (block == num_blocks - 1) ? index_after_last : start(block + 1);
}
/**
* @brief Get the number of blocks. Note that this may be different than the desired number of blocks that was passed to the constructor.
*
* @return The number of blocks.
*/
[[nodiscard]] size_t get_num_blocks() const
{
return num_blocks;
}
private:
/**
* @brief The size of each block (except possibly the last block).
*/
size_t block_size = 0;
/**
* @brief The first index in the range.
*/
T first_index = 0;
/**
* @brief The index after the last index in the range.
*/
T index_after_last = 0;
/**
* @brief The number of blocks.
*/
size_t num_blocks = 0;
/**
* @brief The remainder obtained after dividing the total size by the number of blocks.
*/
size_t remainder = 0;
}; // class blocks
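// Worked example (not part of the library): blocks(0, 10, 3) splits the range [0, 10) into
// block_size = 10 / 3 = 3 with remainder = 10 % 3 = 1; the remainder is spread over the
// first blocks, giving [0, 4), [4, 7) and [7, 10), i.e. block sizes 4, 3 and 3.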
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
/**
* @brief A helper class to store a task with an assigned priority.
*/
class [[nodiscard]] pr_task
{
friend class thread_pool;
public:
/**
* @brief Construct a new task with an assigned priority by copying the task.
*
* @param task_ The task.
* @param priority_ The desired priority.
*/
explicit pr_task(const std::function<void()>& task_, const priority_t priority_ = 0) : task(task_), priority(priority_) {}
/**
* @brief Construct a new task with an assigned priority by moving the task.
*
* @param task_ The task.
* @param priority_ The desired priority.
*/
explicit pr_task(std::function<void()>&& task_, const priority_t priority_ = 0) : task(std::move(task_)), priority(priority_) {}
/**
* @brief Compare the priority of two tasks.
*
* @param lhs The first task.
* @param rhs The second task.
* @return `true` if the first task has a lower priority than the second task, `false` otherwise.
*/
[[nodiscard]] friend bool operator<(const pr_task& lhs, const pr_task& rhs)
{
return lhs.priority < rhs.priority;
}
private:
/**
* @brief The task.
*/
std::function<void()> task = {};
/**
* @brief The priority of the task.
*/
priority_t priority = 0;
}; // class pr_task
#endif
// ============
// Private data
// ============
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
/**
 * @brief A flag indicating whether the workers should pause. When set to `true`, the workers temporarily stop retrieving new tasks out of the queue, although any tasks already being executed will keep running until they are finished. When set to `false` again, the workers resume retrieving tasks.
*/
bool paused = false;
#endif
/**
* @brief A condition variable to notify `worker()` that a new task has become available.
*/
std::condition_variable task_available_cv = {};
/**
* @brief A condition variable to notify `wait()` that the tasks are done.
*/
std::condition_variable tasks_done_cv = {};
/**
* @brief A queue of tasks to be executed by the threads.
*/
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
std::priority_queue<pr_task> tasks = {};
#else
std::queue<std::function<void()>> tasks = {};
#endif
/**
* @brief A counter for the total number of currently running tasks.
*/
size_t tasks_running = 0;
/**
* @brief A mutex to synchronize access to the task queue by different threads.
*/
mutable std::mutex tasks_mutex = {};
/**
* @brief The number of threads in the pool.
*/
concurrency_t thread_count = 0;
/**
* @brief A smart pointer to manage the memory allocated for the threads.
*/
std::unique_ptr<std::thread[]> threads = nullptr;
/**
* @brief A flag indicating that `wait()` is active and expects to be notified whenever a task is done.
*/
bool waiting = false;
/**
* @brief A flag indicating to the workers to keep running. When set to `false`, the workers terminate permanently.
*/
bool workers_running = false;
}; // class thread_pool
} // namespace BS
#endif
\ No newline at end of file
......@@ -3,7 +3,9 @@
// clang-format off
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <numeric>
#include <stdexcept>
#include <string>
......@@ -21,11 +23,14 @@
#include <ankerl/unordered_dense.h>
#include <unordered_set>
#include "BS_thread_pool.h"
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h"
#include "ygopro-core/common.h"
#include "ygopro-core/card_data.h"
#include "ygopro-core/duel.h"
#include "ygopro-core/ocgapi.h"
// clang-format on
......@@ -892,8 +897,18 @@ public:
}
};
// TODO: 7% performance loss
static std::shared_timed_mutex duel_mtx;
struct MDuel {
intptr_t pduel;
uint64_t seed;
std::vector<CardCode> main_deck0;
std::vector<CardCode> extra_deck0;
std::string deck_name0;
std::vector<CardCode> main_deck1;
std::vector<CardCode> extra_deck1;
std::string deck_name1;
};
static std::mutex duel_mtx;
inline Card db_query_card(const SQLite::Database &db, CardCode code) {
SQLite::Statement query1(db, "SELECT * FROM datas WHERE id=?");
......@@ -1237,7 +1252,7 @@ public:
"play_mode"_.Bind(std::string("bot")),
"verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false));
"record"_.Bind(false), "async_reset"_.Bind(true));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
......@@ -1248,7 +1263,7 @@ public:
"obs:actions_"_.Bind(
Spec<uint8_t>({conf["max_options"_], n_action_feats})),
"obs:h_actions_"_.Bind(
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats})),
Spec<uint8_t>({conf["n_history_actions"_], n_action_feats + 2})),
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
......@@ -1301,6 +1316,10 @@ constexpr int32_t duel_options_ = ((rules_ & 0xFF) << 16) + (0 & 0xFFFF);
class YGOProEnv : public Env<YGOProEnvSpec> {
protected:
constexpr static int init_lp_ = 8000;
constexpr static int startcount_ = 5;
constexpr static int drawcount_ = 1;
std::string deck1_;
std::string deck2_;
std::vector<uint32> main_deck0_;
......@@ -1324,7 +1343,7 @@ protected:
PlayerId ai_player_;
intptr_t pduel_;
intptr_t pduel_ = 0;
Player *players_[2]; // abstract class must be pointer
std::uniform_int_distribution<uint64_t> dist_int_;
......@@ -1365,19 +1384,17 @@ protected:
uint64_t step_time_count_ = 0;
double reset_time_ = 0;
double reset_time_1_ = 0;
double reset_time_2_ = 0;
double reset_time_3_ = 0;
uint64_t reset_time_count_ = 0;
const int n_history_actions_;
// circular buffer for history actions of player 0
TArray<uint8_t> history_actions_0_;
int ha_p_0_ = 0;
std::vector<CardId> h_card_ids_0_;
// circular buffer for history actions of player 1
TArray<uint8_t> history_actions_1_;
int ha_p_1_ = 0;
std::vector<CardId> h_card_ids_1_;
// circular buffer for history actions
TArray<uint8_t> history_actions_;
int ha_p_ = 0;
std::vector<CardId> h_card_ids_;
std::unordered_set<std::string> revealed_;
......@@ -1403,32 +1420,45 @@ protected:
// MSG_SELECT_COUNTER
int n_counters_ = 0;
// async reset
const bool async_reset_;
int n_lives_ = 0;
std::future<MDuel> duel_fut_;
BS::thread_pool pool_;
std::mt19937 duel_gen_;
public:
YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id),
max_episode_steps_(spec.config["max_episode_steps"_]),
elapsed_step_(max_episode_steps_ + 1), dist_int_(0, 0xffffffff),
deck1_(spec.config["deck1"_]), deck2_(spec.config["deck2"_]),
player_(spec.config["player"_]),
player_(spec.config["player"_]), players_{nullptr, nullptr},
play_modes_(parse_play_modes(spec.config["play_mode"_])),
verbose_(spec.config["verbose"_]), record_(spec.config["record"_]),
n_history_actions_(spec.config["n_history_actions"_]) {
n_history_actions_(spec.config["n_history_actions"_]), pool_(BS::thread_pool(1)),
async_reset_(spec.config["async_reset"_]) {
if (record_) {
if (!verbose_) {
throw std::runtime_error("record mode must be used with verbose mode and num_envs=1");
}
// replay_data_ = new uint8_t[MAX_REPLAY_SIZE];
// rdata_ = replay_data_;
}
duel_gen_ = std::mt19937(dist_int_(gen_));
if (async_reset_) {
duel_fut_ = pool_.submit_task([
this, duel_seed=dist_int_(gen_)] {
return new_duel(duel_seed);
});
}
int max_options = spec.config["max_options"_];
int n_action_feats = spec.state_spec["obs:actions_"_].shape[1];
h_card_ids_0_.resize(max_options);
h_card_ids_1_.resize(max_options);
history_actions_0_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats})));
history_actions_1_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats})));
h_card_ids_.resize(max_options);
history_actions_ = TArray<uint8_t>(Array(
ShapeSpec(sizeof(uint8_t), {n_history_actions_, n_action_feats + 2})));
}
~YGOProEnv() {
......@@ -1452,8 +1482,36 @@ public:
play_modes_.end();
}
void update_time_stat(const clock_t& start, uint64_t time_count, double& time_stat) {
double seconds = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
time_stat = time_stat * (static_cast<double>(time_count) /
(time_count + 1)) + seconds / (time_count + 1);
}
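// Note (added for clarity, not in the original): this is the standard incremental mean,
// new_mean = old_mean * n / (n + 1) + sample / (n + 1). For example, with old_mean = 2.0
// over time_count = 3 samples and a new sample of 4.0, the result is 2.0 * 3/4 + 4.0/4 = 2.5,
// which equals the mean of (2, 2, 2, 4).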
MDuel new_duel(uint32_t seed) {
auto pduel = YGO_CreateDuel(seed);
MDuel mduel{pduel, seed};
for (PlayerId i = 0; i < 2; i++) {
YGO_SetPlayerInfo(pduel, i, init_lp_, startcount_, drawcount_);
auto [main_deck, extra_deck, deck_name] = load_deck(pduel, i, duel_gen_);
if (i == 0) {
mduel.main_deck0 = main_deck;
mduel.extra_deck0 = extra_deck;
mduel.deck_name0 = deck_name;
} else {
mduel.main_deck1 = main_deck;
mduel.extra_deck1 = extra_deck;
mduel.deck_name1 = deck_name;
}
}
YGO_StartDuel(pduel, duel_options_);
return mduel;
}
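// Sketch of the async reset flow (comment added for clarity, not in the original diff):
// a single-threaded BS::thread_pool prepares the next duel while the current episode is
// still being played, so Reset() only has to collect the ready-made duel.
//
//   constructor / end of Reset(): duel_fut_ = pool_.submit_task([...] { return new_duel(seed); });
//   start of Reset():             mduel = duel_fut_.get();  // usually already finished
//   end of Reset():               the previous duel is deleted on the worker thread via YGO_EndDuel()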
void Reset() override {
// clock_t start = clock();
clock_t start = clock();
if (random_mode()) {
play_mode_ = play_modes_[dist_int_(gen_) % play_modes_.size()];
} else {
......@@ -1471,20 +1529,29 @@ public:
turn_count_ = 0;
ms_idx_ = -1;
history_actions_0_.Zero();
history_actions_1_.Zero();
ha_p_0_ = 0;
ha_p_1_ = 0;
history_actions_.Zero();
ha_p_ = 0;
auto duel_seed = dist_int_(gen_);
clock_t _start = clock();
std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx);
YGO_CreateDuel(duel_seed);
ulock.unlock();
intptr_t old_duel = pduel_;
MDuel mduel;
if (async_reset_) {
mduel = duel_fut_.get();
n_lives_ = 1;
} else {
mduel = new_duel(dist_int_(gen_));
}
auto duel_seed = mduel.seed;
pduel_ = mduel.pduel;
int init_lp = 8000;
int startcount = 5;
int drawcount = 1;
deck_name_[0] = mduel.deck_name0;
deck_name_[1] = mduel.deck_name1;
main_deck0_ = mduel.main_deck0;
extra_deck0_ = mduel.extra_deck0;
main_deck1_ = mduel.main_deck1;
extra_deck1_ = mduel.extra_deck1;
for (PlayerId i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
......@@ -1496,15 +1563,13 @@ public:
}
nickname_[i] = nickname;
if ((play_mode_ == kHuman) && (i != ai_player_)) {
players_[i] = new HumanPlayer(nickname_[i], init_lp, i, verbose_);
players_[i] = new HumanPlayer(nickname_[i], init_lp_, i, verbose_);
} else if (play_mode_ == kRandomBot) {
players_[i] = new RandomAI(max_options(), dist_int_(gen_), nickname_[i],
init_lp, i, verbose_);
init_lp_, i, verbose_);
} else {
players_[i] = new GreedyAI(nickname_[i], init_lp, i, verbose_);
players_[i] = new GreedyAI(nickname_[i], init_lp_, i, verbose_);
}
YGO_SetPlayerInfo(pduel_, i, init_lp, startcount, drawcount);
load_deck(i);
lp_[i] = players_[i]->init_lp_;
}
......@@ -1553,9 +1618,9 @@ public:
fwrite(name, 40, 1, fp_);
}
ReplayWriteInt32(init_lp);
ReplayWriteInt32(startcount);
ReplayWriteInt32(drawcount);
ReplayWriteInt32(init_lp_);
ReplayWriteInt32(startcount_);
ReplayWriteInt32(drawcount_);
ReplayWriteInt32(duel_options_);
for (PlayerId i = 0; i < 2; i++) {
......@@ -1573,24 +1638,34 @@ public:
}
YGO_StartDuel(pduel_, duel_options_);
duel_started_ = true;
winner_ = 255;
win_reason_ = 255;
// update_time_stat(_start, reset_time_count_, reset_time_2_);
// _start = clock();
next();
done_ = false;
elapsed_step_ = 0;
WriteState(0.0);
// double seconds = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
// // update reset_time by moving average
// reset_time_ = reset_time_* (static_cast<double>(reset_time_count_) /
// (reset_time_count_ + 1)) + seconds / (reset_time_count_ + 1);
if (async_reset_) {
duel_fut_ = pool_.submit_task([
this, old_duel, duel_seed=dist_int_(gen_)] {
if (old_duel != 0) {
YGO_EndDuel(old_duel);
}
return new_duel(duel_seed);
});
}
// update_time_stat(_start, reset_time_count_, reset_time_3_);
// update_time_stat(start, reset_time_count_, reset_time_);
// reset_time_count_++;
// if (reset_time_count_ % 20 == 0) {
// fmt::println("Reset time: {:.3f}", reset_time_);
// fmt::println("Reset time: {:.3f}, {:.3f}, {:.3f}", reset_time_ * 1000, reset_time_2_ * 1000, reset_time_3_ * 1000);
// }
}
......@@ -1617,7 +1692,7 @@ public:
options_.push_back(spec);
}
} else {
ms_combs_ = combs;
ms_combs_ = combs;
_callback_multi_select_2_prepare();
}
}
......@@ -1718,23 +1793,22 @@ public:
}
void update_h_card_ids(PlayerId player, int idx) {
auto &h_card_ids = player == 0 ? h_card_ids_0_ : h_card_ids_1_;
h_card_ids[idx] = parse_card_id(options_[idx], player);
h_card_ids_[idx] = parse_card_id(options_[idx], player);
}
void update_history_actions(PlayerId player, int idx) {
auto &history_actions =
player == 0 ? history_actions_0_ : history_actions_1_;
auto &ha_p = player == 0 ? ha_p_0_ : ha_p_1_;
const auto &h_card_ids = player == 0 ? h_card_ids_0_ : h_card_ids_1_;
ha_p--;
if (ha_p < 0) {
ha_p = n_history_actions_ - 1;
if ((msg_ == MSG_SELECT_CHAIN) & (options_[idx][0] == 'c')) {
return;
}
history_actions[ha_p].Zero();
_set_obs_action(history_actions, ha_p, msg_, options_[idx], {},
h_card_ids[idx]);
ha_p_--;
if (ha_p_ < 0) {
ha_p_ = n_history_actions_ - 1;
}
history_actions_[ha_p_].Zero();
_set_obs_action(history_actions_, ha_p_, msg_, options_[idx], {},
h_card_ids_[idx]);
history_actions_[ha_p_](13) = static_cast<uint8_t>(player);
history_actions_[ha_p_](14) = static_cast<uint8_t>(turn_count_);
}
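// Note (added for clarity, not in the original): ha_p_ walks backwards through the circular
// buffer, so the most recent action always sits at index ha_p_. With n_history_actions_ = 4
// and three recorded actions, ha_p_ goes 0 -> 3 -> 2 -> 1, and WriteState() later copies
// rows [ha_p_, n) followed by [0, ha_p_), yielding newest-to-oldest order.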
void show_deck(const std::vector<CardCode> &deck, const std::string &prefix) const {
......@@ -1764,7 +1838,7 @@ public:
}
void show_history_actions(PlayerId player) const {
const auto &ha = player == 0 ? history_actions_0_ : history_actions_1_;
const auto &ha = history_actions_;
// print card ids of history actions
for (int i = 0; i < n_history_actions_; ++i) {
fmt::print("history {}\n", i);
......@@ -1854,13 +1928,10 @@ public:
WriteState(reward, win_reason_);
// double seconds = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
// // update step_time by moving average
// step_time_ = step_time_* (static_cast<double>(step_time_count_) /
// (step_time_count_ + 1)) + seconds / (step_time_count_ + 1);
// update_time_stat(start, step_time_count_, step_time_);
// step_time_count_++;
// if (step_time_count_ % 500 == 0) {
// fmt::println("Step time: {:.3f}", step_time_);
// if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000);
// }
}
......@@ -1870,10 +1941,10 @@ private:
std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
SpecIndex spec2index;
std::vector<int> loc_n_cards;
int offset = 0;
for (auto pi = 0; pi < 2; pi++) {
const PlayerId player = (to_play + pi) % 2;
const bool opponent = pi == 1;
int offset = opponent ? spec_.config["max_cards"_] : 0;
std::vector<std::pair<uint8_t, bool>> configs = {
{LOCATION_DECK, true}, {LOCATION_HAND, true},
{LOCATION_MZONE, false}, {LOCATION_SZONE, false},
......@@ -1982,7 +2053,7 @@ private:
feat(2) = op_lp_1;
feat(3) = op_lp_2;
feat(4) = std::min(turn_count_, 8);
feat(4) = std::min(turn_count_, 16);
feat(5) = phase2id.at(current_phase_);
feat(6) = (me == 0) ? 1 : 0;
feat(7) = (me == tp_) ? 1 : 0;
......@@ -2210,25 +2281,30 @@ private:
}
// ygopro-core API
void YGO_CreateDuel(uint32_t seed) {
intptr_t YGO_CreateDuel(uint32_t seed) {
std::mt19937 rnd(seed);
pduel_ = create_duel(rnd());
// return create_duel(rnd());
duel* pduel = new duel();
pduel->random.reset(rnd());
return (intptr_t)pduel;
}
void YGO_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) {
void YGO_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) const {
set_player_info(pduel, playerid, lp, startcount, drawcount);
}
void YGO_NewCard(intptr_t pduel, uint32 code, uint8 owner, uint8 playerid, uint8 location, uint8 sequence, uint8 position) {
void YGO_NewCard(intptr_t pduel, uint32 code, uint8 owner, uint8 playerid, uint8 location, uint8 sequence, uint8 position) const {
new_card(pduel, code, owner, playerid, location, sequence, position);
}
void YGO_StartDuel(intptr_t pduel, int32 options) {
void YGO_StartDuel(intptr_t pduel, int32 options) const {
start_duel(pduel, options);
}
void YGO_EndDuel(intptr_t pduel) {
end_duel(pduel);
void YGO_EndDuel(intptr_t pduel) const {
// end_duel(pduel);
duel* pd = (duel*)pduel;
delete pd;
}
int32 YGO_GetMessage(intptr_t pduel, byte* buf) {
......@@ -2320,34 +2396,38 @@ private:
n_options = options_.size();
state["info:num_options"_] = n_options;
// update h_card_ids from state
auto &h_card_ids = to_play_ == 0 ? h_card_ids_0_ : h_card_ids_1_;
// update_h_card_ids from state
for (int i = 0; i < n_options; ++i) {
uint8_t spec_index1 = state["obs:actions_"_](i, 0);
uint8_t spec_index2 = state["obs:actions_"_](i, 1);
uint16_t spec_index = (static_cast<uint16_t>(spec_index1) << 8) + static_cast<uint16_t>(spec_index2);
if (spec_index == 0) {
h_card_ids[i] = 0;
h_card_ids_[i] = 0;
} else {
uint8_t card_id1 = state["obs:cards_"_](spec_index - 1, 0);
uint8_t card_id2 = state["obs:cards_"_](spec_index - 1, 1);
h_card_ids[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
h_card_ids_[i] = (static_cast<uint16_t>(card_id1) << 8) + static_cast<uint16_t>(card_id2);
}
}
// write history actions
const auto &ha_p = to_play_ == 0 ? ha_p_0_ : ha_p_1_;
const auto &history_actions =
to_play_ == 0 ? history_actions_0_ : history_actions_1_;
int n1 = n_history_actions_ - ha_p;
int n_action_feats = state["obs:actions_"_].Shape()[1];
int offset = n_history_actions_ - ha_p_;
int n_h_action_feats = history_actions_.Shape()[1];
state["obs:h_actions_"_].Assign((uint8_t *)history_actions[ha_p].Data(),
n_action_feats * n1);
state["obs:h_actions_"_][n1].Assign((uint8_t *)history_actions.Data(),
n_action_feats * ha_p);
state["obs:h_actions_"_].Assign(
(uint8_t *)history_actions_[ha_p_].Data(), n_h_action_feats * offset);
state["obs:h_actions_"_][offset].Assign(
(uint8_t *)history_actions_.Data(), n_h_action_feats * ha_p_);
for (int i = 0; i < n_history_actions_; ++i) {
if (uint8_t(state["obs:h_actions_"_](i, 2)) == 0) {
break;
}
state["obs:h_actions_"_](i, 13) = static_cast<uint8_t>(uint8_t(state["obs:h_actions_"_](i, 13)) == to_play_);
int turn_diff = std::min(16, turn_count_ - uint8_t(state["obs:h_actions_"_](i, 14)));
state["obs:h_actions_"_](i, 14) = static_cast<uint8_t>(turn_diff);
}
}
void show_decision(int idx) {
......@@ -2355,47 +2435,45 @@ private:
options_);
}
void load_deck(PlayerId player, bool shuffle = true) {
std::string deck = player == 0 ? deck1_ : deck2_;
std::vector<CardCode> &main_deck = player == 0 ? main_deck0_ : main_deck1_;
std::vector<CardCode> &extra_deck =
player == 0 ? extra_deck0_ : extra_deck1_;
std::tuple<std::vector<CardCode>, std::vector<CardCode>, std::string>
load_deck(
intptr_t pduel, PlayerId player, std::mt19937& gen, bool shuffle = true) const {
std::string deck_name = player == 0 ? deck1_ : deck2_;
if (deck == "random") {
if (deck_name == "random") {
// generate random deck name
std::uniform_int_distribution<uint64_t> dist_int(0,
deck_names_.size() - 1);
deck_name_[player] = deck_names_[dist_int(gen_)];
} else {
deck_name_[player] = deck;
deck_name = deck_names_[dist_int(gen)];
}
deck = deck_name_[player];
main_deck = main_decks_.at(deck);
extra_deck = extra_decks_.at(deck);
std::vector<CardCode> main_deck = main_decks_.at(deck_name);
std::vector<CardCode> extra_deck = extra_decks_.at(deck_name);
if (verbose_) {
fmt::println("{} {}: {}, main({}), extra({})", player, nickname_[player],
deck, main_deck.size(), extra_deck.size());
deck_name, main_deck.size(), extra_deck.size());
}
if (shuffle) {
std::shuffle(main_deck.begin(), main_deck.end(), gen_);
std::shuffle(main_deck.begin(), main_deck.end(), gen);
}
// add main deck in reverse order following ygopro,
// but since the deck has already been shuffled, just add it in order
for (int i = 0; i < main_deck.size(); i++) {
YGO_NewCard(pduel_, main_deck[i], player, player, LOCATION_DECK, 0,
YGO_NewCard(pduel, main_deck[i], player, player, LOCATION_DECK, 0,
POS_FACEDOWN_DEFENSE);
}
// add extra deck in reverse order following ygopro
for (int i = int(extra_deck.size()) - 1; i >= 0; --i) {
YGO_NewCard(pduel_, extra_deck[i], player, player, LOCATION_EXTRA, 0,
YGO_NewCard(pduel, extra_deck[i], player, player, LOCATION_EXTRA, 0,
POS_FACEDOWN_DEFENSE);
}
return {main_deck, extra_deck, deck_name};
}
void next() {
......@@ -4573,10 +4651,11 @@ private:
void _duel_end(uint8_t player, uint8_t reason) {
winner_ = player;
win_reason_ = reason;
std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx);
YGO_EndDuel(pduel_);
ulock.unlock();
if (async_reset_) {
n_lives_--;
} else {
YGO_EndDuel(pduel_);
}
duel_started_ = false;
}
......