Commit 3c7a7080 authored by biluo.shen

Add PPO

parent ad9c3c34
@@ -60,17 +60,17 @@ class Args:
total_timesteps: int = 100000000
"""total timesteps of the experiments"""
- learning_rate: float = 2.5e-4
+ learning_rate: float = 5e-4
"""the learning rate of the optimizer"""
num_envs: int = 64
"""the number of parallel game environments"""
- num_steps: int = 100
+ num_steps: int = 200
"""the number of steps per env per iteration"""
- buffer_size: int = 200000
+ buffer_size: int = 20000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
- minibatch_size: int = 256
+ minibatch_size: int = 1024
"""the mini-batch size"""
eps: float = 0.05
"""the epsilon for exploration"""
@@ -264,13 +264,13 @@ if __name__ == "__main__":
# ALGO LOGIC: training.
_start = time.time()
- b_inds = rb.get_data_indices()
- if len(b_inds) < args.minibatch_size:
+ if not rb.full:
continue
+ b_inds = rb.get_data_indices()
np.random.shuffle(b_inds)
b_obs, b_actions, b_returns = rb._get_samples(b_inds)
sample_time += time.time() - _start
- for start in range(0, len(b_inds), args.minibatch_size):
+ for start in range(0, len(b_returns), args.minibatch_size):
_start = time.time()
end = start + args.minibatch_size
mb_obs = {
...
import os
import random
import time
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import optree
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, mp_start, setup
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: str = "embeddings_en.npy"
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 8
"""the number of history actions to use"""
play_mode: str = "self"
"""the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
total_timesteps: int = 100000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.99
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 4
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: bool = True
"""whether to use torch.compile to compile the model and functions"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
port: int = 12355
"""the port to use for distributed training"""
# to be filled in runtime
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
def run(local_rank, world_size):
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
setup(args.backend, local_rank, args.world_size, args.port)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if local_rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we need to pass a different seed to each data-parallel worker
args.seed += local_rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0:
print(f"obs_space={obs_space}, action_shape={action_shape}")
envs = RecordEpisodeStatistics(envs)
embeddings = np.load(args.embedding_file)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
agent.load_embeddings(embeddings)
if args.compile:
agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode='reduce-overhead')
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
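# Mean of x over entries flagged as valid; invalid entries are zeroed and excluded from the average.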
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
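# One PPO minibatch update: clipped surrogate policy loss, entropy bonus and (optionally clipped) value loss,
# all masked to valid entries; gradients are reduced across workers here, while the caller clips the grad norm and steps the optimizer.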
def train_step(agent, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
_, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
loss.backward()
reduce_gradidents(agent, args.world_size)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
if args.compile:
train_step = torch.compile(train_step, mode='reduce-overhead')
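# Convert a (possibly nested) structure of numpy arrays to torch tensors on the training device.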
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
avg_ep_returns = []
avg_win_rates = []
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs = to_tensor(envs.reset()[0], dtype=torch.uint8)
next_done = torch.zeros(args.local_num_envs, device=device)
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
model_time = 0
env_time = 0
collect_start = time.time()
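# Rollout phase: step the vectorized envs for num_steps, storing observations, actions, log-probs, rewards, dones and value estimates.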
for step in range(0, args.num_steps):
global_step += args.num_envs
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
_start = time.time()
with torch.no_grad():
action, logprob, _, value, valid = agent.get_action_and_value(next_obs)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
_start = time.time()
next_obs, reward, next_done_, info = envs.step(action)
env_time += time.time() - _start
rewards[step] = to_tensor(reward)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_)
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
avg_ep_returns.append(episode_reward)
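# Count an episode as a win when its return is positive.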
winner = 0 if episode_reward > 0 else 1
avg_win_rates.append(1 - winner)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if len(avg_win_rates) > 100:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
avg_win_rates = []
avg_ep_returns = []
collect_time = time.time() - collect_start
# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
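# GAE: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t); A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}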
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
nextvalues = next_value
else:
nextnonterminal = 1.0 - dones[t + 1]
nextvalues = values[t + 1]
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values
_start = time.time()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds])
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
train_time = time.time() - _start
if local_rank == 0:
print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0:
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Skip the first few iterations (warmup/compilation) so the SPS measurement is accurate
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if local_rank == 0:
writer.close()
if __name__ == "__main__":
mp_start(run)
import torch
import torch.nn as nn
+ from torch.distributions import Categorical
def bytes_to_bin(x, points, intervals):
@@ -18,11 +19,11 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
return points, intervals
- class Agent(nn.Module):
+ class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
- super(Agent, self).__init__()
+ super(Encoder, self).__init__()
self.num_history_action_layers = num_history_action_layers
c = channels
@@ -129,11 +130,6 @@ class Agent(nn.Module):
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
- self.value_head = nn.Sequential(
-     nn.Linear(c, c // 4),
-     nn.ReLU(),
-     nn.Linear(c // 4, 1),
- )
self.init_embeddings()
@@ -147,7 +143,6 @@ class Agent(nn.Module):
nn.init.uniform_(m.weight, -scale, scale)
elif "fc_emb" in n:
nn.init.uniform_(m.weight, -scale, scale)
def load_embeddings(self, embeddings, freeze=True):
weight = self.id_embed.weight
@@ -309,7 +304,78 @@ class Agent(nn.Module):
f_actions = layer(f_actions, f_h_actions)
f_actions = self.action_norm(f_actions)
return f_actions, mask, valid
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(PPOAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
c = channels
self.actor = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
self.critic = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
def load_embeddings(self, embeddings, freeze=True):
self.encoder.load_embeddings(embeddings, freeze)
def get_value(self, x):
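# Mean-pool per-action features over unmasked positions and apply the critic head.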
f_actions, mask, valid = self.encoder(x)
c_mask = 1 - mask.unsqueeze(-1).float()
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
return self.critic(f)
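# Score each candidate action with the actor head, mask invalid actions to -inf, and sample (or evaluate)
# from the resulting categorical distribution; the critic uses features pooled over valid actions.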
def get_action_and_value(self, x, action=None):
f_actions, mask, valid = self.encoder(x)
c_mask = 1 - mask.unsqueeze(-1).float()
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
logits = self.actor(f_actions)[..., 0]
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
class DMCAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(DMCAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
c = channels
self.value_head = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
def load_embeddings(self, embeddings, freeze=True):
self.encoder.load_embeddings(embeddings, freeze)
def forward(self, x):
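# Predict a value for every candidate action, pinning masked (invalid) actions to a large negative value.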
f_actions, mask, valid = self.encoder(x)
values = self.value_head(f_actions)[..., 0]
# values = torch.tanh(values)
values = torch.where(mask, torch.full_like(values, -10), values)
return values, valid
\ No newline at end of file
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
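# Average gradients across data-parallel workers: flatten all grads into one buffer, all-reduce (sum) once,
# then copy the averaged slices back into each parameter's gradient.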
def reduce_gradidents(model, world_size):
if world_size == 1:
return
all_grads_list = []
for param in model.parameters():
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0
for param in model.parameters():
if param.grad is not None:
param.grad.data.copy_(
all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size
)
offset += param.numel()
def setup(backend, rank, world_size, port):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(port)
dist.init_process_group(backend, rank=rank, world_size=world_size)
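# Spawn one training process per worker based on the WORLD_SIZE env var (default 1); run in-process when there is a single worker.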
def mp_start(run):
world_size = int(os.getenv("WORLD_SIZE", "1"))
if world_size == 1:
run(local_rank=0, world_size=world_size)
else:
children = []
for i in range(world_size):
subproc = mp.Process(target=run, args=(i, world_size))
children.append(subproc)
subproc.start()
for i in range(world_size):
children[i].join()