Commit 11261948 authored by Biluo Shen's avatar Biluo Shen

(WIP) OSFP

parent 4d07e48e
*.pt
*.pkl
# Xmake cache # Xmake cache
.xmake/ .xmake/
......
...@@ -140,8 +140,8 @@ if __name__ == "__main__": ...@@ -140,8 +140,8 @@ if __name__ == "__main__":
code_list = f.readlines() code_list = f.readlines()
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]): for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device) state_dict = torch.load(ckpt, map_location=device)
......
...@@ -154,7 +154,7 @@ if __name__ == "__main__": ...@@ -154,7 +154,7 @@ if __name__ == "__main__":
code_list = f.readlines() code_list = f.readlines()
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint: if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device) state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile: if not args.compile:
......
...@@ -5,6 +5,7 @@ from collections import deque ...@@ -5,6 +5,7 @@ from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, Optional from typing import Literal, Optional
import ygoenv import ygoenv
import numpy as np import numpy as np
import tyro import tyro
...@@ -247,7 +248,7 @@ def main(): ...@@ -247,7 +248,7 @@ def main():
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval() agent.eval()
if args.checkpoint: if args.checkpoint:
...@@ -274,9 +275,9 @@ def main(): ...@@ -274,9 +275,9 @@ def main():
if args.compile: if args.compile:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile) # predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device) example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad(): with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False) traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile) train_step = torch.compile(train_step, mode=args.compile)
else: else:
...@@ -389,7 +390,7 @@ def main(): ...@@ -389,7 +390,7 @@ def main():
_start = time.time() _start = time.time()
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
value = traced_model(next_obs)[1].reshape(-1) value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1) nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2) nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
...@@ -403,7 +404,7 @@ def main(): ...@@ -403,7 +404,7 @@ def main():
} }
with torch.no_grad(): with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1) # value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = traced_model(v_obs)[1].reshape(v_end - v_start, -1) value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value values[v_start:v_end] = value
advantages = bootstrap_value_selfplay( advantages = bootstrap_value_selfplay(
...@@ -420,7 +421,7 @@ def main(): ...@@ -420,7 +421,7 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1) b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1) b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1) b_values = values[:args.num_steps].reshape(-1)
b_learns = learns[:args.num_steps].reshape(-1) b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values b_returns = b_advantages + b_values
# Optimizing the policy and value network # Optimizing the policy and value network
......
...@@ -243,7 +243,7 @@ def main(): ...@@ -243,7 +243,7 @@ def main():
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint: if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device)) agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
......
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.eval import evaluate
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 500
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
return logits, value
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
model_time = 0
env_time = 0
collect_start = time.time()
while step < args.collect_length:
global_step += args.num_envs
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
learn = next_to_play == ai_player1
learns[step] = learn
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = predict_step(traced_model, next_obs)[1].reshape(-1)
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 1:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
bootstrap_time = time.time() - _start
_start = time.time()
# flatten the batch
b_obs = {
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
v[:step] = v[args.num_steps:].clone()
for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# Eval with old model
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
if __name__ == "__main__":
main()
...@@ -102,8 +102,10 @@ class Args: ...@@ -102,8 +102,10 @@ class Args:
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None target_kl: Optional[float] = None
"""the target KL divergence threshold""" """the target KL divergence threshold"""
learn_opponent: bool = True learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent""" """if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl" backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training""" """the backend for distributed training"""
...@@ -161,6 +163,9 @@ def main(): ...@@ -161,6 +163,9 @@ def main():
args.num_iterations = args.total_timesteps // args.batch_size args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size) args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size local_env_threads = args.env_threads // args.world_size
...@@ -248,7 +253,7 @@ def main(): ...@@ -248,7 +253,7 @@ def main():
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval() agent.eval()
if args.checkpoint: if args.checkpoint:
...@@ -260,22 +265,19 @@ def main(): ...@@ -260,22 +265,19 @@ def main():
if args.embedding_file: if args.embedding_file:
agent.freeze_embeddings() agent.freeze_embeddings()
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
optim_params = list(agent.parameters()) optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5) optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8) scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device) def predict_step(agent: Agent, next_obs):
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs) logits, value, valid = agent(next_obs)
logits_t, value_t, valid = agent_t(next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
return logits, value return logits, value
from ygoai.rl.ppo import train_step from ygoai.rl.ppo import train_step
...@@ -289,15 +291,18 @@ def main(): ...@@ -289,15 +291,18 @@ def main():
traced_model_t = torch.jit.optimize_for_inference(traced_model_t) traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile) train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup # ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device) obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device) actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device) logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device) rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device) dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device) values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device) learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000) avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000) avg_win_rates = deque(maxlen=1000)
version = 0 version = 0
...@@ -318,6 +323,7 @@ def main(): ...@@ -318,6 +323,7 @@ def main():
np.random.shuffle(ai_player1_) np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype) ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value = 0 next_value = 0
step = 0
for iteration in range(1, args.num_iterations + 1): for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so. # Annealing the rate if instructed to do so.
...@@ -329,7 +335,7 @@ def main(): ...@@ -329,7 +335,7 @@ def main():
model_time = 0 model_time = 0
env_time = 0 env_time = 0
collect_start = time.time() collect_start = time.time()
for step in range(0, args.num_steps): while step < args.collect_length:
global_step += args.num_envs global_step += args.num_envs
for key in obs: for key in obs:
...@@ -339,7 +345,10 @@ def main(): ...@@ -339,7 +345,10 @@ def main():
learns[step] = learn learns[step] = learn
_start = time.time() _start = time.time()
logits, value = predict_step(traced_model, traced_model_t, next_obs, learn) logits, value = predict_step(traced_model, next_obs)
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten() value = value.flatten()
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
action = probs.sample() action = probs.sample()
...@@ -362,6 +371,7 @@ def main(): ...@@ -362,6 +371,7 @@ def main():
env_time += time.time() - _start env_time += time.time() - _start
rewards[step] = to_tensor(reward, device) rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool) next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer: if not writer:
continue continue
...@@ -390,6 +400,8 @@ def main(): ...@@ -390,6 +400,8 @@ def main():
if local_rank == 0: if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}") fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps
_start = time.time() _start = time.time()
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
...@@ -397,23 +409,36 @@ def main(): ...@@ -397,23 +409,36 @@ def main():
value_t = traced_model_t(next_obs)[1].reshape(-1) value_t = traced_model_t(next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t) value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues = torch.where(next_to_play == ai_player1, value, next_value) nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
if step > 0 and iteration != 1:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_self( advantages = bootstrap_value_self(
values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda) values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
returns = advantages + values
bootstrap_time = time.time() - _start bootstrap_time = time.time() - _start
_start = time.time() _start = time.time()
# flatten the batch # flatten the batch
b_obs = { b_obs = {
k: v.reshape((-1,) + v.shape[2:]) k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items() for k, v in obs.items()
} }
b_logprobs = logprobs.reshape(-1) b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_actions = actions.reshape((-1,) + action_shape) b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages.reshape(-1) b_advantages = advantages[:args.num_steps].reshape(-1)
b_returns = returns.reshape(-1) b_values = values[:args.num_steps].reshape(-1)
b_values = values.reshape(-1) b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_learns = learns.reshape(-1) b_returns = b_advantages + b_values
# Optimizing the policy and value network # Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size) b_inds = np.arange(args.local_batch_size)
...@@ -437,7 +462,14 @@ def main(): ...@@ -437,7 +462,14 @@ def main():
if args.target_kl is not None and approx_kl > args.target_kl: if args.target_kl is not None and approx_kl > args.target_kl:
break break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
v[:step] = v[args.num_steps:].clone()
for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start train_time = time.time() - _start
if local_rank == 0: if local_rank == 0:
...@@ -497,7 +529,7 @@ def main(): ...@@ -497,7 +529,7 @@ def main():
_start = time.time() _start = time.time()
eval_return = evaluate( eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval) eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device) eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics # sync the statistics
......
...@@ -44,11 +44,9 @@ class PositionalEncoding(nn.Module): ...@@ -44,11 +44,9 @@ class PositionalEncoding(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False, affine=True):
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.channels = channels self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels c = channels
self.loc_embed = nn.Embedding(9, c) self.loc_embed = nn.Embedding(9, c)
...@@ -165,11 +163,17 @@ class Encoder(nn.Module): ...@@ -165,11 +163,17 @@ class Encoder(nn.Module):
for i in range(num_action_layers) for i in range(num_action_layers)
]) ])
self.action_history_pe = PositionalEncoding(c, dropout=0.0) self.history_action_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([ self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer( nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False) c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_history_action_layers) for i in range(num_action_layers)
]) ])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False) self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
...@@ -287,6 +291,7 @@ class Encoder(nn.Module): ...@@ -287,6 +291,7 @@ class Encoder(nn.Module):
x_cards = x['cards_'] x_cards = x['cards_']
x_global = x['global_'] x_global = x['global_']
x_actions = x['actions_'] x_actions = x['actions_']
batch_size = x_cards.shape[0]
x_cards_1 = x_cards[:, :, :12].long() x_cards_1 = x_cards[:, :, :12].long()
x_cards_2 = x_cards[:, :, 12:].to(torch.float32) x_cards_2 = x_cards[:, :, 12:].to(torch.float32)
...@@ -294,7 +299,10 @@ class Encoder(nn.Module): ...@@ -294,7 +299,10 @@ class Encoder(nn.Module):
x_id = self.encode_card_id(x_cards_1[:, :, :2]) x_id = self.encode_card_id(x_cards_1[:, :, :2])
x_id = self.id_norm(x_id) x_id = self.id_norm(x_id)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 2])) x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask[:, 0] = False
f_loc = self.loc_norm(self.loc_embed(x_loc))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3])) f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3]))
x_feat1 = self.encode_card_feat1(x_cards_1) x_feat1 = self.encode_card_feat1(x_cards_1)
...@@ -306,11 +314,14 @@ class Encoder(nn.Module): ...@@ -306,11 +314,14 @@ class Encoder(nn.Module):
f_cards = torch.cat([x_id, x_feat], dim=-1) f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1) for layer in self.card_net:
# f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_na_card = self.na_card_embed.expand(batch_size, -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1) f_cards = torch.cat([f_na_card, f_cards], dim=1)
# TODO: we can't use it because cudagraph says complex memory
# c_mask = torch.cat([torch.zeros(batch_size, 1, dtype=c_mask.dtype, device=c_mask.device), c_mask], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards) f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global) x_global = self.encode_global(x_global)
...@@ -334,21 +345,24 @@ class Encoder(nn.Module): ...@@ -334,21 +345,24 @@ class Encoder(nn.Module):
valid = x['global_'][:, -1] == 0 valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid mask[:, 0] &= valid
for layer in self.action_card_net: for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask) f_actions = layer(
f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)
if self.num_history_action_layers != 0: x_h_actions = x['h_actions_']
x_h_actions = x['h_actions_'] x_h_actions = x_h_actions.long()
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
x_h_id = self.get_h_action_card_(x_h_actions[..., :2]) h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask[:, 0] = False
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:]) x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1) x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats) f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.action_history_pe(f_h_actions) f_h_actions = self.history_action_pe(f_h_actions)
for layer in self.action_history_net: for layer in self.history_action_net:
f_actions = layer(f_actions, f_h_actions) f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
for layer in self.action_history_net:
f_actions = layer(
f_actions, f_h_actions, tgt_key_padding_mask=mask, memory_key_padding_mask=h_mask)
f_actions = self.action_norm(f_actions) f_actions = self.action_norm(f_actions)
...@@ -385,13 +399,12 @@ class Actor(nn.Module): ...@@ -385,13 +399,12 @@ class Actor(nn.Module):
class PPOAgent(nn.Module): class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False,
num_history_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True): affine=True, a_trans=True):
super(PPOAgent, self).__init__() super(PPOAgent, self).__init__()
self.encoder = Encoder( self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine) channels, num_card_layers, num_action_layers, embedding_shape, bias, affine)
c = channels c = channels
self.actor = Actor(c, a_trans) self.actor = Actor(c, a_trans)
......
...@@ -11,8 +11,7 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv ...@@ -11,8 +11,7 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions) newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy() entropy = probs.entropy()
if not args.learn_opponent: valid = torch.logical_and(valid, mb_learns)
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs logratio = newlogprob - mb_logprobs
ratio = logratio.exp() ratio = logratio.exp()
......
...@@ -1870,10 +1870,10 @@ private: ...@@ -1870,10 +1870,10 @@ private:
std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) { std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
SpecIndex spec2index; SpecIndex spec2index;
std::vector<int> loc_n_cards; std::vector<int> loc_n_cards;
int offset = 0;
for (auto pi = 0; pi < 2; pi++) { for (auto pi = 0; pi < 2; pi++) {
const PlayerId player = (to_play + pi) % 2; const PlayerId player = (to_play + pi) % 2;
const bool opponent = pi == 1; const bool opponent = pi == 1;
int offset = opponent ? spec_.config["max_cards"_] : 0;
std::vector<std::pair<uint8_t, bool>> configs = { std::vector<std::pair<uint8_t, bool>> configs = {
{LOCATION_DECK, true}, {LOCATION_HAND, true}, {LOCATION_DECK, true}, {LOCATION_HAND, true},
{LOCATION_MZONE, false}, {LOCATION_SZONE, false}, {LOCATION_MZONE, false}, {LOCATION_SZONE, false},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment