Commit 7e87d523 authored by biluo.shen

Refactor

parent 3ac7492c
...@@ -24,7 +24,22 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
## Usage
### Serialize agent
After training, we can serialize the trained agent to a single file so it can be used later without the model's source code. The serialized model file ends with the `.ptj` (PyTorch JIT) extension.
```bash
python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embeddings 999 --convert --optimize
```
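The serialized file is a TorchScript archive, so it can be loaded and run without the agent's Python source. A minimal loading sketch (the checkpoint path and device are illustrative):
```python
import torch

# Load the serialized agent; the model's source code is not required.
agent = torch.jit.load("checkpoints/1234_1000M.ptj", map_location="cuda")
agent.eval()
```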
### Battle between two agents
We can use `battle.py` to pit two serialized agents against each other and find out which one is stronger.
```bash
python -u battle.py --deck ../assets/deck --checkpoint1 checkpoints/1234_1000M.ptj --checkpoint2 checkpoints/9876_100M.ptj --num-episodes=256 --num_envs=32 --seed 0
```
### Running
TODO
......
# LSTM Implementations
## Original PPO + LSTM in CleanRL
```python
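# Shapes (following CleanRL's PPO + LSTM): hidden is (num_steps, batch_size, features),
# done is (num_steps * batch_size,), and lstm_state is an (h, c) pair of
# (num_layers, batch_size, hidden_size) tensors.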
not_done = (~done.reshape((-1, batch_size))).float()
new_hidden = []
for i in range(hidden.shape[0]):
h, lstm_state = self.lstm(
hidden[i].unsqueeze(0),
(
not_done[i].view(1, -1, 1) * lstm_state[0],
not_done[i].view(1, -1, 1) * lstm_state[1],
),
)
new_hidden += [h]
new_hidden = torch.cat(new_hidden)
# new_hidden, lstm_state = self.lstm(hidden, lstm_state)
```
The loop runs `num_steps` times (typically 128), so it is slow even with `torch.compile`: compared with a single `self.lstm(hidden, lstm_state)` call over the whole sequence (the commented-out line above), the overall training time is about 4x slower.
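To get a rough sense of the gap, one can time the per-step loop against a single `nn.LSTM` call over the whole sequence. This is a standalone sketch with illustrative sizes; it omits the done masking and any backward pass:
```python
import time
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
num_steps, batch_size, hidden_size = 128, 64, 512
lstm = nn.LSTM(hidden_size, hidden_size).to(device)
x = torch.randn(num_steps, batch_size, hidden_size, device=device)
state = (torch.zeros(1, batch_size, hidden_size, device=device),
         torch.zeros(1, batch_size, hidden_size, device=device))

def bench(fn, n=20):
    # One warm-up call, then average n timed runs.
    fn()
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.time() - t0) / n

def fused():
    # One LSTM call over the whole (num_steps, batch, features) sequence.
    return lstm(x, state)

def per_step_loop():
    # Mirrors the per-step loop above, without the done masking.
    s = state
    outs = []
    for t in range(num_steps):
        out, s = lstm(x[t].unsqueeze(0), s)
        outs.append(out)
    return torch.cat(outs)

print(f"fused: {bench(fused) * 1e3:.2f} ms")
print(f"loop:  {bench(per_step_loop) * 1e3:.2f} ms")
```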
## Custom LSTM with triton
```python
```
\ No newline at end of file
import os
import random
import time
from typing import Optional
from dataclasses import dataclass
import ygoenv
import numpy as np
import optree
import tyro
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, Elo
from ygoai.rl.agent import Agent
from ygoai.rl.buffer import DMCDictBuffer
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, sets `torch.backends.cudnn.deterministic=True`; otherwise `torch.backends.cudnn.benchmark` is enabled"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: str = "embeddings_en.npy"
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 8
"""the number of history actions to use"""
    play_mode: str = "self"
    """the play mode; can be a combination of 'self', 'bot' and 'random', e.g. 'self+bot'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
total_timesteps: int = 100000000
"""total timesteps of the experiments"""
learning_rate: float = 5e-4
"""the learning rate of the optimizer"""
num_envs: int = 64
"""the number of parallel game environments"""
num_steps: int = 200
"""the number of steps per env per iteration"""
buffer_size: int = 20000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
minibatch_size: int = 1024
"""the mini-batch size"""
eps: float = 0.05
"""the epsilon for exploration"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
log_p: float = 0.1
"""the probability of logging"""
save_freq: int = 100
"""the saving frequency (in terms of iterations)"""
compile: bool = True
"""if toggled, model will be compiled for better performance"""
    torch_threads: Optional[int] = None
    """the number of threads to use for torch, defaults to `$OMP_NUM_THREADS` or 4"""
env_threads: Optional[int] = 32
"""the number of threads to use for envpool, defaults to `num_envs`"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
# to be filled in runtime
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
if __name__ == "__main__":
args = tyro.cli(Args)
args.batch_size = args.num_envs * args.num_steps
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.num_envs,
num_threads=args.env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
envs.num_envs = args.num_envs
obs_space = envs.observation_space
action_space = envs.action_space
print(f"obs_space={obs_space}, action_space={action_space}")
envs = RecordEpisodeStatistics(envs)
embeddings = np.load(args.embedding_file)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
agent.load_embeddings(embeddings)
if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
avg_win_rates = []
avg_ep_returns = []
# elo = Elo()
selfplay = "self" in args.play_mode
rb = DMCDictBuffer(
args.buffer_size,
obs_space,
action_space,
device=device,
n_envs=args.num_envs,
selfplay=selfplay,
)
gamma = np.float32(args.gamma)
global_step = 0
start_time = time.time()
warmup_steps = 0
obs, infos = envs.reset()
num_options = infos['num_options']
to_play = infos['to_play'] if selfplay else None
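    # Each iteration: (1) collect num_steps transitions per env with epsilon-greedy
    # actions into the DMC replay buffer, then (2) once the buffer is full, regress
    # the predicted Q-values onto the discounted episode returns.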
for iteration in range(1, args.num_iterations + 1):
agent.eval()
model_time = 0
env_time = 0
buffer_time = 0
collect_start = time.time()
for step in range(args.num_steps):
global_step += args.num_envs
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
if random.random() < args.eps:
actions_ = np.random.randint(num_options)
actions = torch.from_numpy(actions_).to(device)
else:
_start = time.time()
with torch.no_grad():
values = agent(obs)[0]
actions = torch.argmax(values, dim=1)
actions_ = actions.cpu().numpy()
model_time += time.time() - _start
_start = time.time()
next_obs, rewards, dones, infos = envs.step(actions_)
env_time += time.time() - _start
num_options = infos['num_options']
_start = time.time()
rb.add(obs, actions, rewards, to_play)
buffer_time += time.time() - _start
for idx, d in enumerate(dones):
if d:
_start = time.time()
rb.mark_episode(idx, gamma)
buffer_time += time.time() - _start
if random.random() < args.log_p:
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
writer.add_scalar("charts/episodic_return", episode_reward, global_step)
writer.add_scalar("charts/episodic_length", episode_length, global_step)
if selfplay:
if infos['is_selfplay'][idx]:
# win rate for the first player
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
avg_win_rates.append(1 - winner)
else:
# win rate of agent
winner = 0 if episode_reward > 0 else 1
# elo.update(winner)
else:
avg_ep_returns.append(episode_reward)
winner = 0 if episode_reward > 0 else 1
avg_win_rates.append(1 - winner)
# elo.update(winner)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if len(avg_win_rates) > 100:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
avg_win_rates = []
avg_ep_returns = []
to_play = infos['to_play'] if selfplay else None
obs = next_obs
collect_time = time.time() - collect_start
print(f"global_step={global_step}, collect_time={collect_time}, model_time={model_time}, env_time={env_time}, buffer_time={buffer_time}")
agent.train()
train_start = time.time()
model_time = 0
sample_time = 0
# ALGO LOGIC: training.
_start = time.time()
if not rb.full:
continue
b_inds = rb.get_data_indices()
np.random.shuffle(b_inds)
b_obs, b_actions, b_returns = rb._get_samples(b_inds)
sample_time += time.time() - _start
for start in range(0, len(b_returns), args.minibatch_size):
_start = time.time()
end = start + args.minibatch_size
mb_obs = {
k: v[start:end] for k, v in b_obs.items()
}
mb_actions = b_actions[start:end]
mb_returns = b_returns[start:end]
sample_time += time.time() - _start
_start = time.time()
outputs, valid = agent(mb_obs)
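            # Gather the predicted value of the taken action; invalid (padded) option
            # slots are replaced by the target so they add zero error, and the loss is
            # rescaled to average over valid entries only.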
outputs = torch.gather(outputs, 1, mb_actions).squeeze(1)
outputs = torch.where(valid, outputs, mb_returns)
loss = F.mse_loss(mb_returns, outputs)
loss = loss * (args.minibatch_size / valid.float().sum())
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
model_time += time.time() - _start
train_time = time.time() - train_start
print(f"global_step={global_step}, train_time={train_time}, model_time={model_time}, sample_time={sample_time}")
writer.add_scalar("losses/value_loss", loss.item(), global_step)
writer.add_scalar("losses/q_values", outputs.mean().item(), global_step)
if not rb.full or iteration % 10 == 0:
torch.cuda.empty_cache()
if iteration == 10:
warmup_steps = global_step
start_time = time.time()
if iteration > 10:
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.save_freq == 0:
save_path = f"checkpoints/agent.pt"
print(f"Saving model to {save_path}")
torch.save(agent.state_dict(), save_path)
envs.close()
writer.close()
\ No newline at end of file
import os
import random
import time
from typing import Optional, Literal
from dataclasses import dataclass
import ygoenv
import numpy as np
import optree
import tyro
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, Elo
from ygoai.rl.agent import Agent
from ygoai.rl.buffer import DMCDictBuffer
from ygoai.rl.dist import reduce_gradidents
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, sets `torch.backends.cudnn.deterministic=True`; otherwise `torch.backends.cudnn.benchmark` is enabled"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: str = "embeddings_en.npy"
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 8
"""the number of history actions to use"""
    play_mode: str = "self"
    """the play mode; can be a combination of 'self', 'bot' and 'random', e.g. 'self+bot'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
total_timesteps: int = 100000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 64
"""the number of parallel game environments"""
num_steps: int = 100
"""the number of steps per env per iteration"""
buffer_size: int = 200000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
minibatch_size: int = 256
"""the mini-batch size"""
eps: float = 0.05
"""the epsilon for exploration"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
log_p: float = 0.1
"""the probability of logging"""
save_freq: int = 100
"""the saving frequency (in terms of iterations)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: bool = True
"""if toggled, model will be compiled for better performance"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 32
"""the number of threads to use for envpool, defaults to `num_envs`"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
port: int = 12355
"""the port to use for distributed training"""
# to be filled in runtime
local_buffer_size: int = 0
"""the local buffer size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
def setup(backend, rank, world_size, port):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(port)
dist.init_process_group(backend, rank=rank, world_size=world_size)
def run(local_rank, world_size):
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.local_buffer_size = int(args.buffer_size // args.world_size)
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
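    # Per-process (local) sizes: each of the world_size workers drives local_num_envs
    # environments, keeps its own replay buffer of local_buffer_size transitions, and
    # trains on local_minibatch_size samples per gradient step.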
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
setup(args.backend, local_rank, args.world_size, args.port)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if local_rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# TRY NOT TO MODIFY: seeding
    # CRUCIAL: note that we need to pass a different seed to each data-parallelism worker
args.seed += local_rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_space = envs.action_space
if local_rank == 0:
print(f"obs_space={obs_space}, action_space={action_space}")
envs = RecordEpisodeStatistics(envs)
embeddings = np.load(args.embedding_file)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
agent.load_embeddings(embeddings)
if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
avg_win_rates = []
avg_ep_returns = []
# elo = Elo()
selfplay = "self" in args.play_mode
rb = DMCDictBuffer(
args.local_buffer_size,
obs_space,
action_space,
device=device,
n_envs=args.local_num_envs,
selfplay=selfplay,
)
gamma = np.float32(args.gamma)
global_step = 0
warmup_steps = 0
start_time = time.time()
obs, infos = envs.reset()
num_options = infos['num_options']
to_play = infos['to_play'] if selfplay else None
for iteration in range(1, args.num_iterations + 1):
agent.eval()
model_time = 0
env_time = 0
buffer_time = 0
collect_start = time.time()
for step in range(args.num_steps):
global_step += args.num_envs
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
if random.random() < args.eps:
actions_ = np.random.randint(num_options)
actions = torch.from_numpy(actions_).to(device)
else:
_start = time.time()
with torch.no_grad():
values = agent(obs)[0]
actions = torch.argmax(values, dim=1)
actions_ = actions.cpu().numpy()
model_time += time.time() - _start
_start = time.time()
next_obs, rewards, dones, infos = envs.step(actions_)
env_time += time.time() - _start
num_options = infos['num_options']
_start = time.time()
rb.add(obs, actions, rewards, to_play)
buffer_time += time.time() - _start
obs = next_obs
to_play = infos['to_play'] if selfplay else None
for idx, d in enumerate(dones):
if d:
_start = time.time()
rb.mark_episode(idx, gamma)
buffer_time += time.time() - _start
if writer and random.random() < args.log_p:
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
writer.add_scalar("charts/episodic_return", episode_reward, global_step)
writer.add_scalar("charts/episodic_length", episode_length, global_step)
if selfplay:
if infos['is_selfplay'][idx]:
# win rate for the first player
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
avg_win_rates.append(1 - winner)
else:
# win rate of agent
winner = 0 if episode_reward > 0 else 1
# elo.update(winner)
else:
avg_ep_returns.append(episode_reward)
winner = 0 if episode_reward > 0 else 1
avg_win_rates.append(1 - winner)
# elo.update(winner)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if len(avg_win_rates) > 100:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
avg_win_rates = []
avg_ep_returns = []
collect_time = time.time() - collect_start
if writer:
print(f"global_step={global_step}, collect_time={collect_time}, model_time={model_time}, env_time={env_time}, buffer_time={buffer_time}")
agent.train()
train_start = time.time()
model_time = 0
sample_time = 0
# ALGO LOGIC: training.
_start = time.time()
if not rb.full:
continue
b_inds = rb.get_data_indices()
np.random.shuffle(b_inds)
b_obs, b_actions, b_returns = rb._get_samples(b_inds)
        if args.world_size > 1:
            # Keep every rank on the same number of minibatch updates by truncating
            # to the smallest sample count across workers; otherwise the gradient
            # all-reduce would go out of sync.
            n_samples = torch.tensor(b_returns.shape[0], device=device, dtype=torch.int64)
            dist.all_reduce(n_samples, op=dist.ReduceOp.MIN)
            n_samples = n_samples.item()
            b_obs = {k: v[:n_samples] for k, v in b_obs.items()}
            b_actions = b_actions[:n_samples]
            b_returns = b_returns[:n_samples]
sample_time += time.time() - _start
for start in range(0, len(b_returns), args.local_minibatch_size):
_start = time.time()
end = start + args.local_minibatch_size
mb_obs = {
k: v[start:end] for k, v in b_obs.items()
}
mb_actions = b_actions[start:end]
mb_returns = b_returns[start:end]
sample_time += time.time() - _start
_start = time.time()
outputs, valid = agent(mb_obs)
outputs = torch.gather(outputs, 1, mb_actions).squeeze(1)
outputs = torch.where(valid, outputs, mb_returns)
loss = F.mse_loss(mb_returns, outputs)
loss = loss * (args.local_minibatch_size / valid.float().sum())
optimizer.zero_grad()
loss.backward()
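            # Synchronize (reduce) gradients across data-parallel workers before clipping and stepping.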
reduce_gradidents(agent)
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
model_time += time.time() - _start
if not rb.full or iteration % 10 == 0:
torch.cuda.empty_cache()
train_time = time.time() - train_start
if writer:
print(f"global_step={global_step}, train_time={train_time}, model_time={model_time}, sample_time={sample_time}")
writer.add_scalar("losses/value_loss", loss.item(), global_step)
writer.add_scalar("losses/q_values", outputs.mean().item(), global_step)
if iteration == 10:
warmup_steps = global_step
start_time = time.time()
if iteration > 10:
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.save_freq == 0:
save_path = f"checkpoints/agent.pt"
print(f"Saving model to {save_path}")
torch.save(agent.state_dict(), save_path)
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if writer:
writer.close()
if __name__ == "__main__":
world_size = int(os.getenv("WORLD_SIZE", "1"))
if world_size == 1:
run(local_rank=0, world_size=world_size)
else:
children = []
for i in range(world_size):
subproc = mp.Process(target=run, args=(i, world_size))
children.append(subproc)
subproc.start()
for i in range(world_size):
children[i].join()
...@@ -149,7 +149,7 @@ if __name__ == "__main__": ...@@ -149,7 +149,7 @@ if __name__ == "__main__":
code_list = f.readlines() code_list = f.readlines()
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
# agent = agent.eval() # agent = agent.eval()
if args.checkpoint: if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device) state_dict = torch.load(args.checkpoint, map_location=device)
......
import os import os
import random import random
import time import time
from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, Optional from typing import Literal, Optional
import ygoenv import ygoenv
import numpy as np import numpy as np
import optree
import tyro import tyro
import torch import torch
...@@ -18,10 +18,12 @@ import torch.distributed as dist ...@@ -18,10 +18,12 @@ import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor
from ygoai.rl.agent import PPOAgent as Agent from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, mp_start, setup from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.eval import evaluate
@dataclass @dataclass
...@@ -38,7 +40,7 @@ class Args: ...@@ -38,7 +40,7 @@ class Args:
# Algorithm specific arguments # Algorithm specific arguments
env_id: str = "YGOPro-v0" env_id: str = "YGOPro-v0"
"""the id of the environment""" """the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk" deck: str = "../assets/deck"
"""the deck file to use""" """the deck file to use"""
deck1: Optional[str] = None deck1: Optional[str] = None
"""the deck file for the first player""" """the deck file for the first player"""
...@@ -46,21 +48,23 @@ class Args: ...@@ -46,21 +48,23 @@ class Args:
"""the deck file for the second player""" """the deck file for the second player"""
code_list_file: str = "code_list.txt" code_list_file: str = "code_list.txt"
"""the code list file for card embeddings""" """the code list file for card embeddings"""
embedding_file: Optional[str] = "embeddings_en.npy" embedding_file: Optional[str] = None
"""the embedding file for card embeddings""" """the embedding file for card embeddings"""
max_options: int = 24 max_options: int = 24
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 16 n_history_actions: int = 16
"""the number of history actions to use""" """the number of history actions to use"""
play_mode: str = "self" play_mode: str = "bot"
"""the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'""" """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2 num_layers: int = 2
"""the number of layers for the agent""" """the number of layers for the agent"""
num_channels: int = 128 num_channels: int = 128
"""the number of channels for the agent""" """the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 1000000000 total_timesteps: int = 2000000000
"""total timesteps of the experiments""" """total timesteps of the experiments"""
learning_rate: float = 2.5e-4 learning_rate: float = 2.5e-4
"""the learning rate of the optimizer""" """the learning rate of the optimizer"""
...@@ -70,10 +74,11 @@ class Args: ...@@ -70,10 +74,11 @@ class Args:
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.99 gamma: float = 0.997
"""the discount factor gamma""" """the discount factor gamma"""
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
minibatch_size: int = 256 minibatch_size: int = 256
"""the mini-batch size""" """the mini-batch size"""
update_epochs: int = 2 update_epochs: int = 2
...@@ -92,9 +97,11 @@ class Args: ...@@ -92,9 +97,11 @@ class Args:
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None target_kl: Optional[float] = None
"""the target KL divergence threshold""" """the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl" backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training""" """the backend for distributed training"""
compile: Optional[str] = None compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation""" """Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None torch_threads: Optional[int] = None
...@@ -114,8 +121,10 @@ class Args: ...@@ -114,8 +121,10 @@ class Args:
"""the number of iterations to save the model""" """the number of iterations to save the model"""
log_p: float = 1.0 log_p: float = 1.0
"""the probability of logging""" """the probability of logging"""
port: int = 12356 eval_episodes: int = 128
"""the port to use for distributed training""" """the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# to be filled in runtime # to be filled in runtime
local_batch_size: int = 0 local_batch_size: int = 0
...@@ -132,7 +141,12 @@ class Args: ...@@ -132,7 +141,12 @@ class Args:
"""the number of processes (computed in runtime)""" """the number of processes (computed in runtime)"""
def run(local_rank, world_size): def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args) args = tyro.cli(Args)
args.world_size = world_size args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size args.local_num_envs = args.num_envs // args.world_size
...@@ -150,12 +164,12 @@ def run(local_rank, world_size): ...@@ -150,12 +164,12 @@ def run(local_rank, world_size):
torch.set_float32_matmul_precision('high') torch.set_float32_matmul_precision('high')
if args.world_size > 1: if args.world_size > 1:
setup(args.backend, local_rank, args.world_size, args.port) torchrun_setup(args.backend, local_rank)
timestamp = int(time.time()) timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None writer = None
if local_rank == 0: if rank == 0:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name)) writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text( writer.add_text(
...@@ -169,10 +183,10 @@ def run(local_rank, world_size): ...@@ -169,10 +183,10 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: seeding # TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += local_rank args.seed += rank
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank) torch.manual_seed(args.seed - rank)
if args.torch_deterministic: if args.torch_deterministic:
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
else: else:
...@@ -180,7 +194,7 @@ def run(local_rank, world_size): ...@@ -180,7 +194,7 @@ def run(local_rank, world_size):
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu") device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file) deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck args.deck2 = args.deck2 or deck
...@@ -195,15 +209,33 @@ def run(local_rank, world_size): ...@@ -195,15 +209,33 @@ def run(local_rank, world_size):
deck2=args.deck2, deck2=args.deck2,
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
play_mode=args.play_mode, play_mode='self',
) )
envs.num_envs = args.local_num_envs envs.num_envs = args.local_num_envs
obs_space = envs.observation_space obs_space = envs.observation_space
action_shape = envs.action_space.shape action_shape = envs.action_space.shape
if local_rank == 0: if local_rank == 0:
print(f"obs_space={obs_space}, action_shape={action_shape}") fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file: if args.embedding_file:
embeddings = np.load(args.embedding_file) embeddings = np.load(args.embedding_file)
...@@ -211,91 +243,66 @@ def run(local_rank, world_size): ...@@ -211,91 +243,66 @@ def run(local_rank, world_size):
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if args.embedding_file: agent.eval()
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings) agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8) scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def masked_mean(x, valid): def predict_step(agent: Agent, next_obs):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def train_step(agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
with autocast(enabled=args.fp16_train):
_, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent, next_obs):
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): with autocast(enabled=args.fp16_eval):
logits, values = agent(next_obs) logits, value, valid = agent(next_obs)
return logits, values return logits, value
from ygoai.rl.ppo import train_step
if args.compile: if args.compile:
train_step = torch.compile(train_step, mode=args.compile_mode) # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
predict_step = torch.compile(predict_step, mode=args.compile_mode) # predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
def to_tensor(x, dtype=torch.float32): train_step = torch.compile(train_step, mode=args.compile)
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
# ALGO Logic: Storage setup # ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device) obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device) actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device) logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device) rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device) dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device) values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
avg_ep_returns = [] learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_win_rates = [] avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
# TRY NOT TO MODIFY: start the game # TRY NOT TO MODIFY: start the game
global_step = 0 global_step = 0
warmup_steps = 0 warmup_steps = 0
start_time = time.time() start_time = time.time()
next_obs = to_tensor(envs.reset()[0], dtype=torch.uint8) next_obs, info = envs.reset()
next_done = torch.zeros(args.local_num_envs, device=device) next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
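    # next_value1 / next_value2 cache the most recent value estimate for the learning
    # agent's turn and the opponent's turn respectively; they are used to bootstrap
    # returns in self-play and are reset to 0 at episode boundaries.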
for iteration in range(1, args.num_iterations + 1): for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so. # Annealing the rate if instructed to do so.
...@@ -313,68 +320,72 @@ def run(local_rank, world_size): ...@@ -313,68 +320,72 @@ def run(local_rank, world_size):
for key in obs: for key in obs:
obs[key][step] = next_obs[key] obs[key][step] = next_obs[key]
dones[step] = next_done dones[step] = next_done
learn = next_to_play == ai_player1
learns[step] = learn
_start = time.time() _start = time.time()
torch._inductor.cudagraph_mark_step_begin() logits, value = predict_step(traced_model, next_obs)
logits, value = predict_step(agent, next_obs) value = value.flatten()
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
action = probs.sample() action = probs.sample()
logprob = probs.log_prob(action) logprob = probs.log_prob(action)
values[step] = value.flatten() values[step] = value
actions[step] = action actions[step] = action
logprobs[step] = logprob logprobs[step] = logprob
action = action.cpu().numpy() action = action.cpu().numpy()
model_time += time.time() - _start model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time() _start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action) next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start env_time += time.time() - _start
rewards[step] = to_tensor(reward) rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_) next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
collect_time = time.time() - collect_start
print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
if not writer: if not writer:
continue continue
for idx, d in enumerate(next_done_): for idx, d in enumerate(next_done_):
if d: if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx] episode_length = info['l'][idx]
episode_reward = info['r'][idx] episode_reward = info['r'][idx] * pl
winner = 0 if episode_reward > 0 else 1 win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward) avg_ep_returns.append(episode_reward)
avg_win_rates.append(1 - winner) avg_win_rates.append(win)
if random.random() < args.log_p: if random.random() < args.log_p:
n = 100 n = 100
if random.random() < 10/n or iteration <= 2: if random.random() < 10/n or iteration <= 2:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}") fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if len(avg_win_rates) > n: if random.random() < 1/n:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step) writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
avg_win_rates = [] writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
avg_ep_returns = []
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1) value = traced_model(next_obs)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device) nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
lastgaelam = 0 nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
for t in reversed(range(args.num_steps)): advantages = bootstrap_value_selfplay(
if t == args.num_steps - 1: values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
nextnonterminal = 1.0 - next_done
nextvalues = next_value
else:
nextnonterminal = 1.0 - dones[t + 1]
nextvalues = values[t + 1]
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values returns = advantages + values
bootstrap_time = time.time() - _start
_start = time.time() _start = time.time()
# flatten the batch # flatten the batch
...@@ -387,6 +398,7 @@ def run(local_rank, world_size): ...@@ -387,6 +398,7 @@ def run(local_rank, world_size):
b_advantages = advantages.reshape(-1) b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1) b_returns = returns.reshape(-1)
b_values = values.reshape(-1) b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network # Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size) b_inds = np.arange(args.local_batch_size)
...@@ -399,12 +411,11 @@ def run(local_rank, world_size): ...@@ -399,12 +411,11 @@ def run(local_rank, world_size):
mb_obs = { mb_obs = {
k: v[mb_inds] for k, v in b_obs.items() k: v[mb_inds] for k, v in b_obs.items()
} }
torch._inductor.cudagraph_mark_step_begin()
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \ old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds]) b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(agent, args.world_size) reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
clipfracs.append(clipfrac.item()) clipfracs.append(clipfrac.item())
...@@ -414,17 +425,16 @@ def run(local_rank, world_size): ...@@ -414,17 +425,16 @@ def run(local_rank, world_size):
train_time = time.time() - _start train_time = time.time() - _start
print(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}", flush=True) if local_rank == 0:
# if local_rank == 0: fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
# print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true) var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes # TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0: if rank == 0:
if iteration % args.save_interval == 0 or iteration == args.num_iterations: if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt")) torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
...@@ -444,15 +454,37 @@ def run(local_rank, world_size): ...@@ -444,15 +454,37 @@ def run(local_rank, world_size):
start_time = time.time() start_time = time.time()
warmup_steps = global_step warmup_steps = global_step
if iteration > SPS_warmup_iters: if iteration > SPS_warmup_iters:
print("SPS:", SPS) if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step) writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# Eval with old model
if args.world_size > 1: if args.world_size > 1:
dist.destroy_process_group() dist.destroy_process_group()
envs.close() envs.close()
if local_rank == 0: if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close() writer.close()
if __name__ == "__main__": if __name__ == "__main__":
mp_start(run) main()
...@@ -20,7 +20,7 @@ from torch.cuda.amp import GradScaler, autocast ...@@ -20,7 +20,7 @@ from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent from ygoai.rl.agent2 import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs from ygoai.rl.buffer import create_obs
...@@ -136,6 +136,8 @@ class Args: ...@@ -136,6 +136,8 @@ class Args:
"""the number of iterations (computed in runtime)""" """the number of iterations (computed in runtime)"""
world_size: int = 0 world_size: int = 0
"""the number of processes (computed in runtime)""" """the number of processes (computed in runtime)"""
num_minibatches: int = 0
"""the number of mini-batches (computed in runtime)"""
def main(): def main():
...@@ -151,6 +153,7 @@ def main(): ...@@ -151,6 +153,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size) args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps) args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size) args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
...@@ -240,7 +243,7 @@ def main(): ...@@ -240,7 +243,7 @@ def main():
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if args.checkpoint: if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device)) agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
...@@ -268,9 +271,9 @@ def main(): ...@@ -268,9 +271,9 @@ def main():
std = (var + eps).sqrt() std = (var + eps).sqrt()
return (x - mean) / std return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns): def train_step(agent: Agent, scaler, mb_obs, lstm_state, mb_dones, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train): with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs) logits, newvalue, valid, _ = agent(mb_obs, lstm_state, mb_dones)
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions) newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy() entropy = probs.entropy()
...@@ -315,18 +318,23 @@ def main(): ...@@ -315,18 +318,23 @@ def main():
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent: Agent, next_obs): def predict_step(agent: Agent, next_obs, next_lstm_state, next_done):
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs) logits, value, valid, next_lstm_state = agent(next_obs, next_lstm_state, next_done)
return logits, value return logits, value, next_lstm_state
if args.compile: if args.compile:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile) # predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device) obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
with torch.no_grad(): with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False) traced_model = torch.jit.trace(agent, (obs, next_lstm_state, next_done), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile) train_step = torch.compile(train_step, mode=args.compile)
...@@ -353,6 +361,11 @@ def main(): ...@@ -353,6 +361,11 @@ def main():
next_to_play_ = info["to_play"] next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_) next_to_play = to_tensor(next_to_play_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool) next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
ai_player1_ = np.concatenate([ ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64), np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
...@@ -363,6 +376,7 @@ def main(): ...@@ -363,6 +376,7 @@ def main():
next_value2 = 0 next_value2 = 0
for iteration in range(1, args.num_iterations + 1): for iteration in range(1, args.num_iterations + 1):
initial_lstm_state = (next_lstm_state[0].clone(), next_lstm_state[1].clone())
# Annealing the rate if instructed to do so. # Annealing the rate if instructed to do so.
if args.anneal_lr: if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations frac = 1.0 - (iteration - 1.0) / args.num_iterations
...@@ -383,7 +397,7 @@ def main(): ...@@ -383,7 +397,7 @@ def main():
learns[step] = learn learns[step] = learn
_start = time.time() _start = time.time()
logits, value = predict_step(traced_model, next_obs) logits, value, next_lstm_state = predict_step(traced_model, next_obs, next_lstm_state, next_done)
value = value.flatten() value = value.flatten()
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
action = probs.sample() action = probs.sample()
...@@ -438,8 +452,7 @@ def main(): ...@@ -438,8 +452,7 @@ def main():
_start = time.time() _start = time.time()
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
# value = agent.get_value(next_obs).reshape(-1) value = traced_model(next_obs, next_lstm_state, next_done)[1].reshape(-1)
value = traced_model(next_obs)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device) advantages = torch.zeros_like(rewards).to(device)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1) nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2) nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
...@@ -535,24 +548,32 @@ def main(): ...@@ -535,24 +548,32 @@ def main():
} }
b_logprobs = logprobs.reshape(-1) b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape) b_actions = actions.reshape((-1,) + action_shape)
b_dones = dones.reshape(-1)
b_advantages = advantages.reshape(-1) b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1) b_returns = returns.reshape(-1)
b_values = values.reshape(-1) b_values = values.reshape(-1)
b_learns = learns.reshape(-1) b_learns = learns.reshape(-1)
# Optimizing the policy and value network # Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size) assert args.local_num_envs % args.num_minibatches == 0
envsperbatch = args.local_num_envs // args.num_minibatches # minibatch_size // num_steps
envinds = np.arange(args.local_num_envs)
flatinds = np.arange(args.local_batch_size).reshape(args.num_steps, args.local_num_envs)
clipfracs = [] clipfracs = []
for epoch in range(args.update_epochs): for epoch in range(args.update_epochs):
np.random.shuffle(b_inds) np.random.shuffle(envinds)
for start in range(0, args.local_batch_size, args.local_minibatch_size): for start in range(0, args.local_num_envs, envsperbatch):
end = start + args.local_minibatch_size end = start + envsperbatch
mb_inds = b_inds[start:end] mbenvinds = envinds[start:end]
mb_inds = flatinds[:, mbenvinds].ravel() # be really careful about the index
mb_obs = { mb_obs = {
k: v[mb_inds] for k, v in b_obs.items() k: v[mb_inds] for k, v in b_obs.items()
} }
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = train_step(
agent, scaler, mb_obs, (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
b_dones[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds]) b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(optim_params, args.world_size) reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm) nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
...@@ -606,16 +627,23 @@ def main(): ...@@ -606,16 +627,23 @@ def main():
episode_rewards = [] episode_rewards = []
eval_win_rates = [] eval_win_rates = []
e_obs = eval_envs.reset()[0] e_obs = eval_envs.reset()[0]
e_dones_ = np.zeros(local_eval_num_envs, dtype=np.bool_)
e_next_lstm_state = (
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
)
while True: while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8) e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_logits = predict_step(traced_model, e_obs)[0] e_dones = to_tensor(e_dones_, dtype=torch.bool)
e_logits, _, e_next_lstm_state = predict_step(traced_model, e_obs, e_next_lstm_state, e_dones)
e_probs = torch.softmax(e_logits, dim=-1) e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy() e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1) e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions) e_obs, e_rewards, e_dones_, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones): for idx, d in enumerate(e_dones_):
if d: if d:
episode_length = e_info['l'][idx] episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx] episode_reward = e_info['r'][idx]
......
...@@ -8,7 +8,6 @@ from typing import Literal, Optional ...@@ -8,7 +8,6 @@ from typing import Literal, Optional
import ygoenv import ygoenv
import numpy as np import numpy as np
import optree
import tyro import tyro
import torch import torch
...@@ -19,10 +18,12 @@ import torch.distributed as dist ...@@ -19,10 +18,12 @@ import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor
from ygoai.rl.agent import PPOAgent as Agent from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, mp_start, setup from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_self
from ygoai.rl.eval import evaluate
@dataclass @dataclass
...@@ -39,7 +40,7 @@ class Args: ...@@ -39,7 +40,7 @@ class Args:
# Algorithm specific arguments # Algorithm specific arguments
env_id: str = "YGOPro-v0" env_id: str = "YGOPro-v0"
"""the id of the environment""" """the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk" deck: str = "../assets/deck"
"""the deck file to use""" """the deck file to use"""
deck1: Optional[str] = None deck1: Optional[str] = None
"""the deck file for the first player""" """the deck file for the first player"""
...@@ -47,21 +48,23 @@ class Args: ...@@ -47,21 +48,23 @@ class Args:
"""the deck file for the second player""" """the deck file for the second player"""
code_list_file: str = "code_list.txt" code_list_file: str = "code_list.txt"
"""the code list file for card embeddings""" """the code list file for card embeddings"""
embedding_file: Optional[str] = "embeddings_en.npy" embedding_file: Optional[str] = None
"""the embedding file for card embeddings""" """the embedding file for card embeddings"""
max_options: int = 24 max_options: int = 24
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 16 n_history_actions: int = 16
"""the number of history actions to use""" """the number of history actions to use"""
play_mode: str = "bot" play_mode: str = "bot"
"""the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'""" """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2 num_layers: int = 2
"""the number of layers for the agent""" """the number of layers for the agent"""
num_channels: int = 128 num_channels: int = 128
"""the number of channels for the agent""" """the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 1000000000 total_timesteps: int = 2000000000
"""total timesteps of the experiments""" """total timesteps of the experiments"""
learning_rate: float = 2.5e-4 learning_rate: float = 2.5e-4
"""the learning rate of the optimizer""" """the learning rate of the optimizer"""
...@@ -76,7 +79,7 @@ class Args: ...@@ -76,7 +79,7 @@ class Args:
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
update_win_rate: float = 0.6 update_win_rate: float = 0.55
"""the required win rate to update the agent""" """the required win rate to update the agent"""
update_return: float = 0.1 update_return: float = 0.1
"""the required return to update the agent""" """the required return to update the agent"""
...@@ -99,9 +102,11 @@ class Args: ...@@ -99,9 +102,11 @@ class Args:
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None target_kl: Optional[float] = None
"""the target KL divergence threshold""" """the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl" backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training""" """the backend for distributed training"""
compile: Optional[str] = None compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation""" """Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None torch_threads: Optional[int] = None
...@@ -121,10 +126,10 @@ class Args: ...@@ -121,10 +126,10 @@ class Args:
"""the number of iterations to save the model""" """the number of iterations to save the model"""
log_p: float = 1.0 log_p: float = 1.0
"""the probability of logging""" """the probability of logging"""
port: int = 12356
"""the port to use for distributed training"""
eval_episodes: int = 128 eval_episodes: int = 128
"""the number of episodes to evaluate the model""" """the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# to be filled in runtime # to be filled in runtime
local_batch_size: int = 0 local_batch_size: int = 0
...@@ -141,7 +146,12 @@ class Args: ...@@ -141,7 +146,12 @@ class Args:
"""the number of processes (computed in runtime)""" """the number of processes (computed in runtime)"""
def run(local_rank, world_size): def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args) args = tyro.cli(Args)
args.world_size = world_size args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size args.local_num_envs = args.num_envs // args.world_size
...@@ -159,12 +169,12 @@ def run(local_rank, world_size): ...@@ -159,12 +169,12 @@ def run(local_rank, world_size):
torch.set_float32_matmul_precision('high') torch.set_float32_matmul_precision('high')
if args.world_size > 1: if args.world_size > 1:
setup(args.backend, local_rank, args.world_size, args.port) torchrun_setup(args.backend, local_rank)
timestamp = int(time.time()) timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None writer = None
if local_rank == 0: if rank == 0:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name)) writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text( writer.add_text(
...@@ -178,10 +188,10 @@ def run(local_rank, world_size): ...@@ -178,10 +188,10 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: seeding # TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += local_rank args.seed += rank
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank) torch.manual_seed(args.seed - rank)
if args.torch_deterministic: if args.torch_deterministic:
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
else: else:
...@@ -189,7 +199,7 @@ def run(local_rank, world_size): ...@@ -189,7 +199,7 @@ def run(local_rank, world_size):
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu") device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file) deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck args.deck2 = args.deck2 or deck
...@@ -210,7 +220,7 @@ def run(local_rank, world_size): ...@@ -210,7 +220,7 @@ def run(local_rank, world_size):
obs_space = envs.observation_space obs_space = envs.observation_space
action_shape = envs.action_space.shape action_shape = envs.action_space.shape
if local_rank == 0: if local_rank == 0:
print(f"obs_space={obs_space}, action_shape={action_shape}") fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size local_eval_episodes = args.eval_episodes // args.world_size
...@@ -238,99 +248,47 @@ def run(local_rank, world_size): ...@@ -238,99 +248,47 @@ def run(local_rank, world_size):
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file: if args.embedding_file:
agent1.load_embeddings(embeddings) agent.freeze_embeddings()
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2.load_state_dict(agent1.state_dict())
optim_params = list(agent1.parameters()) optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5) optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8) scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def masked_mean(x, valid): agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
x = x.masked_fill(~valid, 0) agent_t.eval()
return x.sum() / valid.float().sum() agent_t.load_state_dict(agent.state_dict())
def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad(): def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent1: Agent, agent2: Agent, next_obs, learn):
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): with autocast(enabled=args.fp16_eval):
logits1, value1, valid = agent1(next_obs) logits, value, valid = agent(next_obs)
logits2, value2, valid = agent2(next_obs) logits_t, value_t, valid = agent_t(next_obs)
logits = torch.where(learn[:, None], logits1, logits2) logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value1, value2) value = torch.where(learn[:, None], value, value_t)
return logits, value return logits, value
def eval_step(agent: Agent, next_obs): from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice causes a segfault at startup, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
logits = agent.get_logit(next_obs) traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
return logits traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
if args.compile:
train_step = torch.compile(train_step, mode=args.compile) train_step = torch.compile(train_step, mode=args.compile)
predict_step = torch.compile(predict_step, mode='default')
# eval_step = torch.compile(eval_step, mode=args.compile)
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
# ALGO Logic: Storage setup # ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device) obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
...@@ -349,16 +307,16 @@ def run(local_rank, world_size): ...@@ -349,16 +307,16 @@ def run(local_rank, world_size):
warmup_steps = 0 warmup_steps = 0
start_time = time.time() start_time = time.time()
next_obs, info = envs.reset() next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, dtype=torch.uint8) next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"] next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_) next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool) next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player_ = np.concatenate([ ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64), np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
]) ])
np.random.shuffle(ai_player_) np.random.shuffle(ai_player1_)
ai_player = to_tensor(ai_player_, dtype=next_to_play.dtype) ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value = 0 next_value = 0
for iteration in range(1, args.num_iterations + 1): for iteration in range(1, args.num_iterations + 1):
...@@ -377,11 +335,11 @@ def run(local_rank, world_size): ...@@ -377,11 +335,11 @@ def run(local_rank, world_size):
for key in obs: for key in obs:
obs[key][step] = next_obs[key] obs[key][step] = next_obs[key]
dones[step] = next_done dones[step] = next_done
learn = next_to_play == ai_player learn = next_to_play == ai_player1
learns[step] = learn learns[step] = learn
_start = time.time() _start = time.time()
logits, value = predict_step(agent1, agent2, next_obs, learn) logits, value = predict_step(traced_model, traced_model_t, next_obs, learn)
value = value.flatten() value = value.flatten()
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
action = probs.sample() action = probs.sample()
...@@ -393,23 +351,24 @@ def run(local_rank, world_size): ...@@ -393,23 +351,24 @@ def run(local_rank, world_size):
action = action.cpu().numpy() action = action.cpu().numpy()
model_time += time.time() - _start model_time += time.time() - _start
next_value = torch.where(learn, value, next_value) * (1 - next_done.float()) next_nonterminal = 1 - next_done.float()
next_value = torch.where(learn, value, next_value) * next_nonterminal
_start = time.time() _start = time.time()
to_play = next_to_play_ to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action) next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"] next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_) next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start env_time += time.time() - _start
rewards[step] = to_tensor(reward) rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool) next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
if not writer: if not writer:
continue continue
for idx, d in enumerate(next_done_): for idx, d in enumerate(next_done_):
if d: if d:
pl = 1 if to_play[idx] == ai_player_[idx] else -1 pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx] episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
...@@ -421,65 +380,27 @@ def run(local_rank, world_size): ...@@ -421,65 +380,27 @@ def run(local_rank, world_size):
if random.random() < 10/n or iteration <= 2: if random.random() < 10/n or iteration <= 2:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}") fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n: if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step) writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step) writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start collect_time = time.time() - collect_start
print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True) if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
value = agent1.get_value(next_obs).reshape(-1) value = traced_model(next_obs)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device) value_t = traced_model_t(next_obs)[1].reshape(-1)
nextvalues = torch.where(next_to_play == ai_player, value, next_value) value = torch.where(next_to_play == ai_player1, value, value_t)
done_used = torch.zeros_like(next_done, dtype=torch.bool) nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
reward = 0 advantages = bootstrap_value_self(
lastgaelam = 0 values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
for t in reversed(range(args.num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward = rewards[t]
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# if not done_used:
# reward = reward
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# reward = rewards[t]
# delta = reward + args.gamma * nextvalues - values[t]
# lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
# advantages[t] = lastgaelam_
# nextvalues = values[t]
# lastgaelam = lastgaelam_
# else:
# if dones[t+1]:
# reward = -rewards[t]
# done_used = False
# else:
# reward = reward
learn = learns[t]
if t != args.num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn.int() - 0.5)
reward = torch.where(next_done, rewards[t] * sp, torch.where(learn & done_used, 0, reward))
real_done = next_done | ~done_used
nextvalues = torch.where(real_done, 0, nextvalues)
lastgaelam = torch.where(real_done, 0, lastgaelam)
done_used = torch.where(
next_done, learn, torch.where(learn & ~done_used, True, done_used))
delta = reward + args.gamma * nextvalues - values[t]
advantages[t] = lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
nextvalues = torch.where(learn, values[t], nextvalues)
lastgaelam = torch.where(learn, lastgaelam_, lastgaelam)
returns = advantages + values returns = advantages + values
bootstrap_time = time.time() - _start
_start = time.time() _start = time.time()
# flatten the batch # flatten the batch
...@@ -506,8 +427,8 @@ def run(local_rank, world_size): ...@@ -506,8 +427,8 @@ def run(local_rank, world_size):
k: v[mb_inds] for k, v in b_obs.items() k: v[mb_inds] for k, v in b_obs.items()
} }
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \ old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent1, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds]) b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(optim_params, args.world_size) reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm) nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer) scaler.step(optimizer)
...@@ -519,18 +440,17 @@ def run(local_rank, world_size): ...@@ -519,18 +440,17 @@ def run(local_rank, world_size):
train_time = time.time() - _start train_time = time.time() - _start
print(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}", flush=True) if local_rank == 0:
# if local_rank == 0: fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
# print(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true) var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes # TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0: if rank == 0:
if iteration % args.save_interval == 0: if iteration % args.save_interval == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pt")) torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
...@@ -549,10 +469,12 @@ def run(local_rank, world_size): ...@@ -549,10 +469,12 @@ def run(local_rank, world_size):
start_time = time.time() start_time = time.time()
warmup_steps = global_step warmup_steps = global_step
if iteration > SPS_warmup_iters: if iteration > SPS_warmup_iters:
print("SPS:", SPS) if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step) writer.add_scalar("charts/SPS", SPS, global_step)
if local_rank == 0: if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device) should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else: else:
...@@ -561,62 +483,42 @@ def run(local_rank, world_size): ...@@ -561,62 +483,42 @@ def run(local_rank, world_size):
dist.all_reduce(should_update, op=dist.ReduceOp.SUM) dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0 should_update = should_update.item() > 0
if should_update: if should_update:
agent2.load_state_dict(agent1.state_dict()) agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1 version += 1
if local_rank == 0: if rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt")) torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}") print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear() avg_win_rates.clear()
avg_ep_returns.clear() avg_ep_returns.clear()
_start = time.time() _start = time.time()
episode_lengths = [] eval_return = evaluate(
episode_rewards = [] eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)
eval_win_rates = [] eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
e_obs = eval_envs.reset()[0]
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_logits = eval_step(agent1, e_obs)
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= local_eval_episodes:
break
eval_return = np.mean(episode_rewards[:local_eval_episodes])
eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
# sync the statistics # sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG) dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
if local_rank == 0: eval_return = eval_stats.cpu().numpy()
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy() if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step) writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step) if local_rank == 0:
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
eval_time = time.time() - _start eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return}, eval_ep_len={eval_ep_len}, eval_win_rate={eval_win_rate}") fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# Eval with old model
if args.world_size > 1: if args.world_size > 1:
dist.destroy_process_group() dist.destroy_process_group()
envs.close() envs.close()
if local_rank == 0: if rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt")) torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close() writer.close()
if __name__ == "__main__": if __name__ == "__main__":
mp_start(run) main()
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_self
from ygoai.rl.eval import evaluate
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
"""if toggled, the samples from the opponent will be used to train the agent"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 500
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file:
embeddings = np.load(args.embedding_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
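# agent_t is a frozen snapshot of the learner that acts as the self-play opponent.
# predict_step below runs both models on the same batch of observations and, per
# environment, keeps the output of whichever side is to act (the `learn` mask),
# so the learner always plays against the last accepted version of itself.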
def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
logits_t, value_t, valid = agent_t(next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
return logits, value
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice causes a segfault at startup, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
agent = torch.compile(agent, mode=args.compile)
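# The opponent model never trains, so it is traced once with TorchScript and
# optimized for inference; it is re-traced whenever the opponent snapshot is
# refreshed from the learner (see the should_update block below).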
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
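# Assign the learner to player 0 in half of the environments and to player 1 in
# the other half (shuffled), so the policy is trained from both turn orders.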
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
agent.eval()
model_time = 0
env_time = 0
collect_start = time.time()
for step in range(0, args.num_steps):
global_step += args.num_envs
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
learn = next_to_play == ai_player1
learns[step] = learn
_start = time.time()
logits, value = predict_step(agent, traced_model_t, next_obs, learn)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
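# Keep the learner's most recent value estimate so it can be used to bootstrap
# through steps where the opponent acts; it is reset to 0 at episode boundaries.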
next_nonterminal = 1 - next_done.float()
next_value = torch.where(learn, value, next_value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done
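# Bootstrap from the learner's value of the final observation in environments where
# the learner acts next, otherwise from the last value the learner produced for that
# environment; bootstrap_value_self is expected to accumulate GAE over the learner's
# own steps only (using the learns mask).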
with torch.no_grad():
value = agent(next_obs)[1].reshape(-1)
value_t = traced_model_t(next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
advantages = bootstrap_value_self(
values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
returns = advantages + values
bootstrap_time = time.time() - _start
agent.train()
_start = time.time()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network
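# Transitions are flattened over (num_steps, num_envs) and shuffled freely; the
# policy here has no recurrent state, so minibatches need not keep per-env
# sequences contiguous. b_learns marks the learner's own steps so train_step
# can mask the opponent's transitions out of the loss.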
b_inds = np.arange(args.local_batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
train_time = time.time() - _start
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, "agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warm up for the first few iterations to get an accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
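# Evaluate the current policy in the eval envs (play_mode opponents, e.g. the
# greedy bot) and average the resulting episode return across ranks.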
_start = time.time()
agent.eval()
eval_return = evaluate(
eval_envs, agent, local_eval_episodes, device, args.fp16_eval)
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# Eval with old model
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, "agent_final.pt"))
writer.close()
if __name__ == "__main__":
main()
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributions import Categorical
def bytes_to_bin(x, points, intervals): def bytes_to_bin(x, points, intervals):
...@@ -18,7 +17,6 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24): ...@@ -18,7 +17,6 @@ def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0) intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0)
return points, intervals return points, intervals
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
...@@ -138,7 +136,7 @@ class Encoder(nn.Module): ...@@ -138,7 +136,7 @@ class Encoder(nn.Module):
self.action_history_net = nn.ModuleList([ self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer( nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False) c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers) for i in range(num_history_action_layers)
]) ])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False) self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
...@@ -335,307 +333,6 @@ class Encoder(nn.Module): ...@@ -335,307 +333,6 @@ class Encoder(nn.Module):
return f_actions, f_state, mask, valid return f_actions, f_state, mask, valid
class Encoder1(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
self.loc_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.seq_embed = nn.Embedding(76, c)
self.seq_norm = nn.LayerNorm(c, elementwise_affine=affine)
linear = lambda in_features, out_features: nn.Linear(in_features, out_features, bias=bias)
c_num = c // 8
n_bins = 32
self.num_fc = nn.Sequential(
linear(n_bins, c_num),
nn.ReLU(),
)
bin_points, bin_intervals = make_bin_params(n_bins=n_bins)
self.bin_points = nn.Parameter(bin_points, requires_grad=False)
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
if embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(embedding_shape, int):
n_embed, embed_dim = embedding_shape, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
self.id_embed = nn.Embedding(n_embed, embed_dim)
self.id_fc_emb = linear(1024, c // 4)
self.id_norm = nn.LayerNorm(c // 4, elementwise_affine=False)
self.owner_embed = nn.Embedding(2, c // 16)
self.position_embed = nn.Embedding(9, c // 16 * 2)
self.overley_embed = nn.Embedding(2, c // 16)
self.attribute_embed = nn.Embedding(8, c // 16)
self.race_embed = nn.Embedding(27, c // 16)
self.level_embed = nn.Embedding(14, c // 16)
self.counter_embed = nn.Embedding(16, c // 16)
self.type_fc_emb = linear(25, c // 16 * 2)
self.atk_fc_emb = linear(c_num, c // 16)
self.def_fc_emb = linear(c_num, c // 16)
self.feat_norm = nn.LayerNorm(c // 4 * 3, elementwise_affine=affine)
self.na_card_embed = nn.Parameter(torch.randn(1, c) * 0.02, requires_grad=True)
num_heads = max(2, c // 128)
self.card_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_card_layers)
])
self.card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.lp_fc_emb = linear(c_num, c // 4)
self.oppo_lp_fc_emb = linear(c_num, c // 4)
self.turn_embed = nn.Embedding(20, c // 8)
self.phase_embed = nn.Embedding(11, c // 8)
self.if_first_embed = nn.Embedding(2, c // 8)
self.is_my_turn_embed = nn.Embedding(2, c // 8)
self.global_norm_pre = nn.LayerNorm(c, elementwise_affine=affine)
self.global_net = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.global_norm = nn.LayerNorm(c, elementwise_affine=False)
divisor = 8
self.a_msg_embed = nn.Embedding(30, c // divisor)
self.a_act_embed = nn.Embedding(13, c // divisor)
self.a_yesno_embed = nn.Embedding(3, c // divisor)
self.a_phase_embed = nn.Embedding(4, c // divisor)
self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
self.a_position_embed = nn.Embedding(9, c // divisor)
self.a_option_embed = nn.Embedding(6, c // divisor // 2)
self.a_number_embed = nn.Embedding(13, c // divisor // 2)
self.a_place_embed = nn.Embedding(31, c // divisor // 2)
# TODO: maybe same embedding as attribute_embed
self.a_attrib_embed = nn.Embedding(10, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.a_card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.a_card_proj = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.h_id_fc_emb = linear(1024, c)
self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
num_heads = max(2, c // 128)
self.action_card_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
self.init_embeddings()
def init_embeddings(self, scale=0.0001):
for n, m in self.named_modules():
if isinstance(m, nn.Embedding):
nn.init.uniform_(m.weight, -scale, scale)
elif n in ["atk_fc_emb", "def_fc_emb"]:
nn.init.uniform_(m.weight, -scale * 10, scale * 10)
elif n in ["lp_fc_emb", "oppo_lp_fc_emb"]:
nn.init.uniform_(m.weight, -scale, scale)
elif "fc_emb" in n:
nn.init.uniform_(m.weight, -scale, scale)
def load_embeddings(self, embeddings):
weight = self.id_embed.weight
embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
unknown_embed = embeddings.mean(dim=0, keepdim=True)
embeddings = torch.cat([unknown_embed, embeddings], dim=0)
weight.data.copy_(embeddings)
def freeze_embeddings(self):
self.id_embed.weight.requires_grad = False
def num_transform(self, x):
return self.num_fc(bytes_to_bin(x, self.bin_points, self.bin_intervals))
def encode_action_(self, x):
x_a_msg = self.a_msg_embed(x[:, :, 0])
x_a_act = self.a_act_embed(x[:, :, 1])
x_a_yesno = self.a_yesno_embed(x[:, :, 2])
x_a_phase = self.a_phase_embed(x[:, :, 3])
x_a_cancel = self.a_cancel_finish_embed(x[:, :, 4])
x_a_position = self.a_position_embed(x[:, :, 5])
x_a_option = self.a_option_embed(x[:, :, 6])
x_a_number = self.a_number_embed(x[:, :, 7])
x_a_place = self.a_place_embed(x[:, :, 8])
x_a_attrib = self.a_attrib_embed(x[:, :, 9])
return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib
def get_action_card_(self, x, f_cards):
b, n, c = x.shape
m = c // 2
spec_index = x.view(b, n, m, 2)
spec_index = spec_index[..., 0] * 256 + spec_index[..., 1]
mask = spec_index != 0
mask[:, :, 0] = True
spec_index = spec_index.view(b, -1)
B = torch.arange(b, device=spec_index.device)
f_a_actions = f_cards[B[:, None], spec_index]
f_a_actions = f_a_actions.view(b, n, m, -1)
f_a_actions = (f_a_actions * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return f_a_actions
def get_h_action_card_(self, x):
b, n, _ = x.shape
x_ids = x.view(b, n, -1, 2)
x_ids = x_ids[..., 0] * 256 + x_ids[..., 1]
mask = x_ids != 0
mask[:, :, 0] = True
x_ids = self.id_embed(x_ids)
x_ids = self.h_id_fc_emb(x_ids)
x_ids = (x_ids * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return x_ids
def encode_card_id(self, x):
x_id = self.id_embed(x)
x_id = self.id_fc_emb(x_id)
x_id = self.id_norm(x_id)
return x_id
def encode_card_feat1(self, x1):
x_owner = self.owner_embed(x1[:, :, 2])
x_position = self.position_embed(x1[:, :, 3])
x_overley = self.overley_embed(x1[:, :, 4])
x_attribute = self.attribute_embed(x1[:, :, 5])
x_race = self.race_embed(x1[:, :, 6])
x_level = self.level_embed(x1[:, :, 7])
x_counter = self.counter_embed(x1[:, :, 8])
return x_owner, x_position, x_overley, x_attribute, x_race, x_level, x_counter
def encode_card_feat2(self, x2):
x_atk = self.num_transform(x2[:, :, 0:2])
x_atk = self.atk_fc_emb(x_atk)
x_def = self.num_transform(x2[:, :, 2:4])
x_def = self.def_fc_emb(x_def)
x_type = self.type_fc_emb(x2[:, :, 4:])
return x_atk, x_def, x_type
def encode_global(self, x):
x_global_1 = x[:, :4].float()
x_g_lp = self.lp_fc_emb(self.num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))
x_global_2 = x[:, 4:-1].long()
x_g_turn = self.turn_embed(x_global_2[:, 0])
x_g_phase = self.phase_embed(x_global_2[:, 1])
x_g_if_first = self.if_first_embed(x_global_2[:, 2])
x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 3])
x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
return x_global
def forward(self, x):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_card_ids = x_cards[:, :, :2].long()
x_card_ids = x_card_ids[..., 0] * 256 + x_card_ids[..., 1]
x_cards_1 = x_cards[:, :, 2:11].long()
x_cards_2 = x_cards[:, :, 11:].to(torch.float32)
x_id = self.encode_card_id(x_card_ids)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 0]))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 1]))
x_feat1 = self.encode_card_feat1(x_cards_1)
x_feat2 = self.encode_card_feat2(x_cards_2)
x_feat = torch.cat([*x_feat1, *x_feat2], dim=-1)
x_feat = self.feat_norm(x_feat)
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
x_global = self.global_norm_pre(x_global)
f_global = x_global + self.global_net(x_global)
f_global = self.global_norm(f_global)
f_cards = f_cards + f_global.unsqueeze(1)
x_actions = x_actions.long()
max_multi_select = (x_actions.shape[-1] - 9) // 2
mo = max_multi_select * 2
f_a_cards = self.get_action_card_(x_actions[..., :mo], f_cards)
f_a_cards = f_a_cards + self.a_card_proj(self.a_card_norm(f_a_cards))
x_a_feats = self.encode_action_(x_actions[..., mo:])
x_a_feats = torch.cat(x_a_feats, dim=-1)
f_actions = f_a_cards + self.a_feat_norm(x_a_feats)
mask = x_actions[:, :, mo] == 0 # msg == 0
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :mo])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = self.action_norm(f_actions)
f_s_cards_global = f_cards.mean(dim=1)
c_mask = 1 - mask.unsqueeze(-1).float()
f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
return f_actions, f_state, mask, valid
class Actor(nn.Module):
def __init__(self, channels, use_transformer=False):
......
import torch
import torch.nn as nn
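# bytes_to_bin decodes a 2-byte (high, low) integer and converts it into a soft
# "thermometer" encoding over the bins produced by make_bin_params: each output
# element is 1 if the value exceeds that bin's point, 0 if it falls below the
# previous point, and a linear fraction in between.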
def bytes_to_bin(x, points, intervals):
x = x[..., 0] * 256 + x[..., 1]
x = x.unsqueeze(-1)
return torch.clamp((x - points + intervals) / intervals, 0, 1)
def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = torch.linspace(0, x_max1, sig_bins + 1, dtype=torch.float32)[1:]
points2 = torch.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=torch.float32)[1:]
points = torch.cat([points1, points2], dim=0)
intervals = torch.cat([points[0:1], points[1:] - points[:-1]], dim=0)
return points, intervals
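# make_bin_params places sig_bins bins uniformly on [0, 8000] and the remaining
# bins on (8000, x_max], so the common ATK/DEF/LP range gets finer resolution.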
class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
self.loc_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.seq_embed = nn.Embedding(76, c)
self.seq_norm = nn.LayerNorm(c, elementwise_affine=affine)
linear = lambda in_features, out_features: nn.Linear(in_features, out_features, bias=bias)
c_num = c // 8
n_bins = 32
self.num_fc = nn.Sequential(
linear(n_bins, c_num),
nn.ReLU(),
)
bin_points, bin_intervals = make_bin_params(n_bins=n_bins)
self.bin_points = nn.Parameter(bin_points, requires_grad=False)
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
self.count_embed = nn.Embedding(100, c // 16)
self.hand_count_embed = nn.Embedding(100, c // 16)
if embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(embedding_shape, int):
n_embed, embed_dim = embedding_shape, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
self.id_embed = nn.Embedding(n_embed, embed_dim)
self.id_fc_emb = linear(1024, c // 4)
self.id_norm = nn.LayerNorm(c // 4, elementwise_affine=False)
self.owner_embed = nn.Embedding(2, c // 16)
self.position_embed = nn.Embedding(9, c // 16 * 2)
self.overley_embed = nn.Embedding(2, c // 16)
self.attribute_embed = nn.Embedding(8, c // 16)
self.race_embed = nn.Embedding(27, c // 16)
self.level_embed = nn.Embedding(14, c // 16)
self.counter_embed = nn.Embedding(16, c // 16)
self.type_fc_emb = linear(25, c // 16 * 2)
self.atk_fc_emb = linear(c_num, c // 16)
self.def_fc_emb = linear(c_num, c // 16)
self.feat_norm = nn.LayerNorm(c // 4 * 3, elementwise_affine=affine)
self.na_card_embed = nn.Parameter(torch.randn(1, c) * 0.02, requires_grad=True)
num_heads = max(2, c // 128)
self.card_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_card_layers)
])
self.card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.lp_fc_emb = linear(c_num, c // 4)
self.oppo_lp_fc_emb = linear(c_num, c // 4)
self.turn_embed = nn.Embedding(20, c // 8)
self.phase_embed = nn.Embedding(11, c // 8)
self.if_first_embed = nn.Embedding(2, c // 8)
self.is_my_turn_embed = nn.Embedding(2, c // 8)
self.my_deck_fc_emb = linear(1024, c // 4)
self.global_norm_pre = nn.LayerNorm(c * 2, elementwise_affine=affine)
self.global_net = nn.Sequential(
nn.Linear(c * 2, c * 2),
nn.ReLU(),
nn.Linear(c * 2, c * 2),
)
self.global_proj = nn.Linear(c * 2, c)
self.global_norm = nn.LayerNorm(c, elementwise_affine=False)
divisor = 8
self.a_msg_embed = nn.Embedding(30, c // divisor)
self.a_act_embed = nn.Embedding(13, c // divisor)
self.a_yesno_embed = nn.Embedding(3, c // divisor)
self.a_phase_embed = nn.Embedding(4, c // divisor)
self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
self.a_position_embed = nn.Embedding(9, c // divisor)
self.a_option_embed = nn.Embedding(6, c // divisor // 2)
self.a_number_embed = nn.Embedding(13, c // divisor // 2)
self.a_place_embed = nn.Embedding(31, c // divisor // 2)
# TODO: maybe same embedding as attribute_embed
self.a_attrib_embed = nn.Embedding(10, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.a_card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.a_card_proj = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.h_id_fc_emb = linear(1024, c)
self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
num_heads = max(2, c // 128)
self.action_card_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_history_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
self.init_embeddings()
def init_embeddings(self, scale=0.0001):
for n, m in self.named_modules():
if isinstance(m, nn.Embedding):
nn.init.uniform_(m.weight, -scale, scale)
elif n in ["atk_fc_emb", "def_fc_emb"]:
nn.init.uniform_(m.weight, -scale * 10, scale * 10)
elif n in ["lp_fc_emb", "oppo_lp_fc_emb"]:
nn.init.uniform_(m.weight, -scale, scale)
elif "fc_emb" in n:
nn.init.uniform_(m.weight, -scale, scale)
def load_embeddings(self, embeddings):
weight = self.id_embed.weight
embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
unknown_embed = embeddings.mean(dim=0, keepdim=True)
embeddings = torch.cat([unknown_embed, embeddings], dim=0)
weight.data.copy_(embeddings)
def freeze_embeddings(self):
self.id_embed.weight.requires_grad = False
def num_transform(self, x):
return self.num_fc(bytes_to_bin(x, self.bin_points, self.bin_intervals))
def encode_action_(self, x):
x_a_msg = self.a_msg_embed(x[:, :, 0])
x_a_act = self.a_act_embed(x[:, :, 1])
x_a_yesno = self.a_yesno_embed(x[:, :, 2])
x_a_phase = self.a_phase_embed(x[:, :, 3])
x_a_cancel = self.a_cancel_finish_embed(x[:, :, 4])
x_a_position = self.a_position_embed(x[:, :, 5])
x_a_option = self.a_option_embed(x[:, :, 6])
x_a_number = self.a_number_embed(x[:, :, 7])
x_a_place = self.a_place_embed(x[:, :, 8])
x_a_attrib = self.a_attrib_embed(x[:, :, 9])
return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib
def get_action_card_(self, x, f_cards):
b, n, c = x.shape
m = c // 2
spec_index = x.view(b, n, m, 2)
spec_index = spec_index[..., 0] * 256 + spec_index[..., 1]
mask = spec_index != 0
mask[:, :, 0] = True
spec_index = spec_index.view(b, -1)
B = torch.arange(b, device=spec_index.device)
f_a_actions = f_cards[B[:, None], spec_index]
f_a_actions = f_a_actions.view(b, n, m, -1)
f_a_actions = (f_a_actions * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return f_a_actions
def get_h_action_card_(self, x):
b, n, _ = x.shape
x_ids = x.view(b, n, -1, 2)
x_ids = x_ids[..., 0] * 256 + x_ids[..., 1]
mask = x_ids != 0
mask[:, :, 0] = True
x_ids = self.id_embed(x_ids)
x_ids = self.h_id_fc_emb(x_ids)
x_ids = (x_ids * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return x_ids
def encode_card_id(self, x):
x_id = self.id_embed(x)
x_id = self.id_fc_emb(x_id)
x_id = self.id_norm(x_id)
return x_id
def encode_card_feat1(self, x1):
x_owner = self.owner_embed(x1[:, :, 2])
x_position = self.position_embed(x1[:, :, 3])
x_overley = self.overley_embed(x1[:, :, 4])
x_attribute = self.attribute_embed(x1[:, :, 5])
x_race = self.race_embed(x1[:, :, 6])
x_level = self.level_embed(x1[:, :, 7])
x_counter = self.counter_embed(x1[:, :, 8])
return x_owner, x_position, x_overley, x_attribute, x_race, x_level, x_counter
def encode_card_feat2(self, x2):
x_atk = self.num_transform(x2[:, :, 0:2])
x_atk = self.atk_fc_emb(x_atk)
x_def = self.num_transform(x2[:, :, 2:4])
x_def = self.def_fc_emb(x_def)
x_type = self.type_fc_emb(x2[:, :, 4:])
return x_atk, x_def, x_type
def encode_global(self, x):
x_global_1 = x[:, :4].float()
x_g_lp = self.lp_fc_emb(self.num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))
x_global_2 = x[:, 4:8].long()
x_g_turn = self.turn_embed(x_global_2[:, 0])
x_g_phase = self.phase_embed(x_global_2[:, 1])
x_g_if_first = self.if_first_embed(x_global_2[:, 2])
x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 3])
x_global_3 = x[:, 8:22].long()
x_g_cs = self.count_embed(x_global_3).flatten(1)
x_g_my_hand_c = self.hand_count_embed(x_global_3[:, 1])
x_g_op_hand_c = self.hand_count_embed(x_global_3[:, 8])
x_global = torch.cat([
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], dim=-1)
return x_global
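# Forward pass: encode cards, prepend the learned "no card" token and run the
# card transformer; add the encoded global state to every card feature; build
# per-action features by pooling the card features each action references and
# adding its embedded attributes; cross-attend actions to cards (and to history
# actions if enabled). Returns per-action features, a pooled state vector, the
# action padding mask (msg == 0) and the per-sample validity flag.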
def forward(self, x):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_card_ids = x_cards[:, :, :2].long()
x_card_ids = x_card_ids[..., 0] * 256 + x_card_ids[..., 1]
x_cards_1 = x_cards[:, :, 2:11].long()
x_cards_2 = x_cards[:, :, 11:].to(torch.float32)
x_id = self.encode_card_id(x_card_ids)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 0]))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 1]))
x_feat1 = self.encode_card_feat1(x_cards_1)
x_feat2 = self.encode_card_feat2(x_cards_2)
x_feat = torch.cat([*x_feat1, *x_feat2], dim=-1)
x_feat = self.feat_norm(x_feat)
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
x_global = self.global_norm_pre(x_global)
f_global = x_global + self.global_net(x_global)
f_global = self.global_proj(f_global)
f_global = self.global_norm(f_global)
f_cards = f_cards + f_global.unsqueeze(1)
x_actions = x_actions.long()
max_multi_select = (x_actions.shape[-1] - 9) // 2
mo = max_multi_select * 2
f_a_cards = self.get_action_card_(x_actions[..., :mo], f_cards)
f_a_cards = f_a_cards + self.a_card_proj(self.a_card_norm(f_a_cards))
x_a_feats = self.encode_action_(x_actions[..., mo:])
x_a_feats = torch.cat(x_a_feats, dim=-1)
f_actions = f_a_cards + self.a_feat_norm(x_a_feats)
mask = x_actions[:, :, mo] == 0 # msg == 0
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :mo])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = self.action_norm(f_actions)
f_s_cards_global = f_cards.mean(dim=1)
c_mask = 1 - mask.unsqueeze(-1).float()
f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
return f_actions, f_state, mask, valid
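# Actor: projects the recurrent state, adds it to every action feature,
# optionally refines the actions with one TransformerEncoder layer, and outputs
# per-action logits with padded actions masked to -inf.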
class Actor(nn.Module):
def __init__(self, channels, use_transformer=False):
super(Actor, self).__init__()
c = channels
self.state_proj = nn.Sequential(
nn.Linear(c * 2, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.use_transformer = use_transformer
if use_transformer:
self.transformer = nn.TransformerEncoderLayer(
c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=True)
self.head = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
def forward(self, f_actions, h_state, mask):
f_state = self.state_proj(h_state)
# TODO: maybe token concat
f_actions = f_actions + f_state.unsqueeze(1)
if self.use_transformer:
f_actions = self.transformer(f_actions, src_key_padding_mask=mask)
logits = self.head(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits
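# PPOAgent wires the pieces together: Encoder -> LSTM over the pooled state
# vector -> Actor for the policy logits, plus an MLP critic for the value.
# LSTM weights are orthogonally initialized and biases set to zero.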
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True, num_lstm_layers=1):
super(PPOAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
c = channels
self.lstm = nn.LSTM(c * 2, c * 2, num_lstm_layers)
self.actor = Actor(c, a_trans)
self.critic = nn.Sequential(
nn.Linear(c * 2, c // 2),
nn.ReLU(),
nn.Linear(c // 2, 1),
)
self.init_lstm()
def init_lstm(self):
for name, param in self.lstm.named_parameters():
if "bias" in name:
nn.init.constant_(param, 0)
elif "weight" in name:
nn.init.orthogonal_(param, 1.0)
def load_embeddings(self, embeddings):
self.encoder.load_embeddings(embeddings)
def freeze_embeddings(self):
self.encoder.freeze_embeddings()
# def get_logit(self, x):
# f_actions, f_state, mask, valid = self.encoder(x)
# return self.actor(f_actions, mask)
# def get_value(self, x):
# f_actions, f_state, mask, valid = self.encoder(x)
# return self.critic(f_state)
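# encode_lstm runs the whole (num_steps, batch) sequence through nn.LSTM in a
# single call, which is much faster than the per-step loop kept below for
# reference. Note that this fast path ignores `done`, so the hidden state is
# not reset at episode boundaries; the commented loop is the variant that
# resets the state whenever an episode ends.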
def encode_lstm(self, hidden, lstm_state, done):
batch_size = lstm_state[0].shape[1]
hidden = hidden.reshape((-1, batch_size, self.lstm.input_size))
new_hidden, lstm_state = self.lstm(hidden, lstm_state)
# not_done = (~done.reshape((-1, batch_size))).float()
# new_hidden = []
# for i in range(hidden.shape[0]):
# h, lstm_state = self.lstm(
# hidden[i].unsqueeze(0),
# (
# not_done[i].view(1, -1, 1) * lstm_state[0],
# not_done[i].view(1, -1, 1) * lstm_state[1],
# ),
# )
# new_hidden += [h]
# new_hidden = torch.cat(new_hidden)
new_hidden = torch.flatten(new_hidden, 0, 1)
return new_hidden, lstm_state
def forward(self, x, lstm_state, done):
f_actions, f_state, mask, valid = self.encoder(x)
h_state, lstm_state = self.encode_lstm(f_state, lstm_state, done)
logits = self.actor(f_actions, h_state, mask)
return logits, self.critic(h_state), valid, lstm_state
import numpy as np
import torch
from torch.cuda.amp import autocast
from ygoai.rl.utils import to_tensor
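# Greedy evaluation: step the vectorized envs with argmax actions (optionally
# under autocast) until num_episodes have finished, then report the mean
# episodic return, episode length and win rate (reward > 0 counts as a win).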
def evaluate(envs, model, num_episodes, device, fp16_eval=False):
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
while True:
obs = to_tensor(obs, device, dtype=torch.uint8)
with torch.no_grad():
with autocast(enabled=fp16_eval):
logits = model(obs)[0]
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
actions = probs.argmax(axis=1)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if d:
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
import torch
from torch.distributions import Categorical
from torch.cuda.amp import autocast
from ygoai.rl.utils import masked_normalize, masked_mean
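# One PPO minibatch update under AMP: clipped surrogate policy loss, optional
# clipped value loss and entropy bonus, all averaged only over `valid` steps
# (further restricted to the learning player when learn_opponent is off).
# The loss is backpropagated and the gradients unscaled here; gradient clipping
# and the optimizer step are left to the caller.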
def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
if not args.learn_opponent:
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
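# GAE for self-play from a single learner's perspective (a vectorized version
# of the pseudocode kept below): terminal rewards are sign-flipped when the
# final step belongs to the opponent, bootstrapping is suppressed until the
# learner acts again after a terminal, and `done_used` tracks whether the
# terminal reward has already been consumed.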
def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
done_used = torch.ones_like(next_done, dtype=torch.bool)
reward = 0
lastgaelam = 0
for t in reversed(range(num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward = rewards[t]
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# if not done_used:
# reward = reward
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# reward = rewards[t]
# delta = reward + args.gamma * nextvalues - values[t]
# lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
# advantages[t] = lastgaelam_
# nextvalues = values[t]
# lastgaelam = lastgaelam_
# else:
# if dones[t+1]:
# reward = -rewards[t]
# done_used = False
# else:
# reward = reward
learn = learns[t]
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn.int() - 0.5)
reward = torch.where(next_done, rewards[t] * sp, torch.where(learn & done_used, 0, reward))
real_done = next_done | ~done_used
nextvalues = torch.where(real_done, 0, nextvalues)
lastgaelam = torch.where(real_done, 0, lastgaelam)
done_used = torch.where(
next_done, learn, torch.where(learn & ~done_used, True, done_used))
delta = reward + gamma * nextvalues - values[t]
lastgaelam_ = delta + gamma * gae_lambda * lastgaelam
advantages[t] = lastgaelam_
nextvalues = torch.where(learn, values[t], nextvalues)
lastgaelam = torch.where(learn, lastgaelam_, lastgaelam)
return advantages
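# Same GAE scheme as above, but for full self-play: advantages are computed for
# both players of the match at once, with separate reward / next-value / GAE
# accumulators per player, and torch.where selects the active player's estimate
# at every timestep.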
def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# delta1 = reward1 + args.gamma * nextvalues1 - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# advantages[t] = lastgaelam1_
# nextvalues1 = values[t]
# lastgaelam1 = lastgaelam1_
# else:
# if dones[t+1]:
# reward2 = rewards[t]
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
#
# reward1 = -rewards[t]
# done_used1 = False
# else:
# if not done_used2:
# reward2 = reward2
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
# else:
# reward2 = rewards[t]
# reward1 = reward1
# delta2 = reward2 + args.gamma * nextvalues2 - values[t]
# lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
# advantages[t] = lastgaelam2_
# nextvalues2 = values[t]
# lastgaelam2 = lastgaelam2_
learn1 = learns[t]
learn2 = ~learn1
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - values[t]
delta2 = reward2 + gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
\ No newline at end of file
...@@ -2,6 +2,8 @@ import re
import numpy as np
import gymnasium as gym
import optree
import torch
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
...@@ -84,3 +86,21 @@ class Elo:
def expect_result(self, p0, p1):
exp = (p0 - p1) / 400.0
return 1 / ((10.0 ** (exp)) + 1)
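# masked_mean / masked_normalize compute statistics with invalid entries zeroed
# out and the denominator taken as the number of valid entries; they are used
# to restrict the PPO losses and advantage normalization to valid steps.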
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
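# to_tensor recursively maps a (possibly nested) structure of numpy arrays to
# torch tensors on the given device via optree.tree_map.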
def to_tensor(x, device, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)