Commit 7e87d523 authored by biluo.shen

Refactor

parent 3ac7492c
@@ -24,7 +24,22 @@ YGO Agent is a project to create a Yu-Gi-Oh! AI using deep learning (LLMs, RL).
## Usage
TODO
### Serialize agent
After training, we can serialize the trained agent model to a single file that can be used later without keeping the source code of the model. The serialized model file ends with the `.ptj` (PyTorch JIT) extension.
```bash
python -u eval.py --agent --checkpoint checkpoints/1234_1000M.pt --num_embeddings 999 --convert --optimize
```
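For reference, a serialized agent can then be loaded without the project sources. This is a minimal sketch, assuming the `.ptj` file is an ordinary TorchScript archive produced by the conversion above and that the example checkpoint path exists:
```python
import torch

# Load the serialized agent (assumed to be a TorchScript archive saved by eval.py --convert).
model = torch.jit.load("checkpoints/1234_1000M.ptj", map_location="cpu")
model.eval()
# The expected inputs depend on the agent version, e.g. model(obs) for the
# feed-forward agent or model(obs, lstm_state, done) for the recurrent one.
```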
### Battle between two agents
We can use `battle.py` to let two agents play against each other and find out which one is better.
```bash
python -u battle.py --deck ../assets/deck --checkpoint1 checkpoints/1234_1000M.ptj --checkpoint2 checkpoints/9876_100M.ptj --num-episodes=256 --num_envs=32 --seed 0
```
### Running
TODO
# LSTM Implementations
## Original PPO + LSTM in CleanRL
```python
# Reset (zero) the LSTM state wherever `done` is set, i.e. at episode boundaries.
not_done = (~done.reshape((-1, batch_size))).float()
new_hidden = []
for i in range(hidden.shape[0]):
    # One LSTM call per time step so the state can be masked between steps.
    h, lstm_state = self.lstm(
        hidden[i].unsqueeze(0),
        (
            not_done[i].view(1, -1, 1) * lstm_state[0],
            not_done[i].view(1, -1, 1) * lstm_state[1],
        ),
    )
    new_hidden += [h]
new_hidden = torch.cat(new_hidden)
# Fused alternative (cannot reset the state mid-sequence):
# new_hidden, lstm_state = self.lstm(hidden, lstm_state)
```
The loop runs for `num_steps` iterations (typically 128), so it is slow even with `torch.compile`. Compared with a single fused LSTM call over the whole sequence, the overall training time is about 4x slower.
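To illustrate the gap, here is a rough, self-contained microbenchmark sketch on dummy data (the shapes, hidden size 512 and batch 64, are illustrative assumptions, not the project's configuration). The fused call is only equivalent when no episode terminates inside the rollout, which is why the training code keeps the per-step loop:
```python
import time
import torch

# Illustrative shapes only (assumptions, not the project's configuration).
num_steps, batch_size, hidden_size = 128, 64, 512
lstm = torch.nn.LSTM(hidden_size, hidden_size)
x = torch.randn(num_steps, batch_size, hidden_size)
state = (torch.zeros(1, batch_size, hidden_size),
         torch.zeros(1, batch_size, hidden_size))

with torch.no_grad():
    # Per-step loop, as in the CleanRL snippet (the state could be masked here).
    start = time.time()
    outs, h_state = [], state
    for t in range(num_steps):
        out, h_state = lstm(x[t].unsqueeze(0), h_state)
        outs.append(out)
    torch.cat(outs)
    print("per-step loop:", time.time() - start)

    # Single fused call; only equivalent when no episode resets occur mid-rollout.
    start = time.time()
    lstm(x, state)
    print("fused call   :", time.time() - start)
```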
## Custom LSTM with triton
```python
```
\ No newline at end of file
@@ -149,7 +149,7 @@ if __name__ == "__main__":
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
# agent = agent.eval()
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
@@ -20,7 +20,7 @@ from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.agent2 import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
@@ -136,6 +136,8 @@ class Args:
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_minibatches: int = 0
"""the number of mini-batches (computed in runtime)"""
def main():
@@ -151,6 +153,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
@@ -240,7 +243,7 @@ def main():
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
@@ -268,9 +271,9 @@ def main():
std = (var + eps).sqrt()
return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
def train_step(agent: Agent, scaler, mb_obs, lstm_state, mb_dones, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
logits, newvalue, valid, _ = agent(mb_obs, lstm_state, mb_dones)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
@@ -315,18 +318,23 @@ def main():
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent: Agent, next_obs):
def predict_step(agent: Agent, next_obs, next_lstm_state, next_done):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
return logits, value
logits, value, valid, next_lstm_state = agent(next_obs, next_lstm_state, next_done)
return logits, value, next_lstm_state
if args.compile:
        # It seems that using torch.compile twice causes a segfault at startup, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
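        # Example inputs (observations, LSTM state, done flags) with the right
        # shapes for torch.jit.trace.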
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
traced_model = torch.jit.trace(agent, (obs, next_lstm_state, next_done), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
@@ -353,6 +361,11 @@ def main():
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
@@ -363,6 +376,7 @@ def main():
next_value2 = 0
for iteration in range(1, args.num_iterations + 1):
initial_lstm_state = (next_lstm_state[0].clone(), next_lstm_state[1].clone())
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
@@ -383,7 +397,7 @@ def main():
learns[step] = learn
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
logits, value, next_lstm_state = predict_step(traced_model, next_obs, next_lstm_state, next_done)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
@@ -438,8 +452,7 @@ def main():
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
# value = agent.get_value(next_obs).reshape(-1)
value = traced_model(next_obs)[1].reshape(-1)
value = traced_model(next_obs, next_lstm_state, next_done)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
@@ -535,25 +548,33 @@ def main():
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_dones = dones.reshape(-1)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
assert args.local_num_envs % args.num_minibatches == 0
envsperbatch = args.local_num_envs // args.num_minibatches # minibatch_size // num_steps
envinds = np.arange(args.local_num_envs)
flatinds = np.arange(args.local_batch_size).reshape(args.num_steps, args.local_num_envs)
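        # Recurrent PPO: sample whole environments per minibatch so that each
        # sequence of num_steps stays contiguous for the LSTM; flatinds maps
        # (step, env) pairs back to flat batch indices.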
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
np.random.shuffle(envinds)
for start in range(0, args.local_num_envs, envsperbatch):
end = start + envsperbatch
mbenvinds = envinds[start:end]
mb_inds = flatinds[:, mbenvinds].ravel() # be really careful about the index
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = train_step(
agent, scaler, mb_obs, (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
b_dones[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
@@ -606,16 +627,23 @@ def main():
episode_rewards = []
eval_win_rates = []
e_obs = eval_envs.reset()[0]
e_dones_ = np.zeros(local_eval_num_envs, dtype=np.bool_)
e_next_lstm_state = (
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
)
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_logits = predict_step(traced_model, e_obs)[0]
e_dones = to_tensor(e_dones_, dtype=torch.bool)
e_logits, _, e_next_lstm_state = predict_step(traced_model, e_obs, e_next_lstm_state, e_dones)
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions)
e_obs, e_rewards, e_dones_, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones):
for idx, d in enumerate(e_dones_):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
import numpy as np
import torch
from torch.cuda.amp import autocast
from ygoai.rl.utils import to_tensor
def evaluate(envs, model, num_episodes, device, fp16_eval=False):
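    """Greedy evaluation loop.

    Runs the model with argmax actions until `num_episodes` episodes finish and
    returns the mean episode return, mean episode length and win rate.
    """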
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
while True:
obs = to_tensor(obs, device, dtype=torch.uint8)
with torch.no_grad():
with autocast(enabled=fp16_eval):
logits = model(obs)[0]
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
actions = probs.argmax(axis=1)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if d:
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
import torch
from torch.distributions import Categorical
from torch.cuda.amp import autocast
from ygoai.rl.utils import masked_normalize, masked_mean
def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
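    """One PPO update on a minibatch.

    Computes the clipped policy loss, the (optionally clipped) value loss and the
    entropy bonus, masking invalid actions via `valid` (and opponent steps via
    `mb_learns` unless `args.learn_opponent` is set), then backpropagates through
    the GradScaler for mixed-precision training.
    """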
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
if not args.learn_opponent:
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
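# The commented-out pseudocode inside the functions below is the per-env reference
# logic; the tensor code vectorizes it with torch.where. The underlying recursion
# is standard GAE:
#   delta_t = r_t + gamma * V_{t+1} - V_t
#   A_t     = delta_t + gamma * lambda * A_{t+1}
# advanced only on the learning player's steps and reset at (real) episode boundaries.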
def bootstrap_value_self(values, rewards, dones, learns, nextvalues, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
done_used = torch.ones_like(next_done, dtype=torch.bool)
reward = 0
lastgaelam = 0
for t in reversed(range(num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward = rewards[t]
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# if not done_used:
# reward = reward
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# reward = rewards[t]
# delta = reward + args.gamma * nextvalues - values[t]
# lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
# advantages[t] = lastgaelam_
# nextvalues = values[t]
# lastgaelam = lastgaelam_
# else:
# if dones[t+1]:
# reward = -rewards[t]
# done_used = False
# else:
# reward = reward
learn = learns[t]
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn.int() - 0.5)
reward = torch.where(next_done, rewards[t] * sp, torch.where(learn & done_used, 0, reward))
real_done = next_done | ~done_used
nextvalues = torch.where(real_done, 0, nextvalues)
lastgaelam = torch.where(real_done, 0, lastgaelam)
done_used = torch.where(
next_done, learn, torch.where(learn & ~done_used, True, done_used))
delta = reward + gamma * nextvalues - values[t]
lastgaelam_ = delta + gamma * gae_lambda * lastgaelam
advantages[t] = lastgaelam_
nextvalues = torch.where(learn, values[t], nextvalues)
lastgaelam = torch.where(learn, lastgaelam_, lastgaelam)
return advantages
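# Self-play variant: the same recursion is bookkept separately for the two players.
# learns[t] marks whose turn it was; the terminal reward is mirrored (sp = +1/-1)
# so the non-acting player's trace sees the negated outcome.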
def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# delta1 = reward1 + args.gamma * nextvalues1 - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# advantages[t] = lastgaelam1_
# nextvalues1 = values[t]
# lastgaelam1 = lastgaelam1_
# else:
# if dones[t+1]:
# reward2 = rewards[t]
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
#
# reward1 = -rewards[t]
# done_used1 = False
# else:
# if not done_used2:
# reward2 = reward2
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
# else:
# reward2 = rewards[t]
# reward1 = reward1
# delta2 = reward2 + args.gamma * nextvalues2 - values[t]
# lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
# advantages[t] = lastgaelam2_
# nextvalues2 = values[t]
# lastgaelam2 = lastgaelam2_
learn1 = learns[t]
learn2 = ~learn1
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - values[t]
delta2 = reward2 + gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
\ No newline at end of file
@@ -2,6 +2,8 @@ import re
import numpy as np
import gymnasium as gym
import optree
import torch
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
@@ -83,4 +85,22 @@ class Elo:
def expect_result(self, p0, p1):
exp = (p0 - p1) / 400.0
return 1 / ((10.0 ** (exp)) + 1)
\ No newline at end of file
return 1 / ((10.0 ** (exp)) + 1)
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
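# Convert every numpy array in a (possibly nested) observation structure to a
# torch tensor on the target device.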
def to_tensor(x, device, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)