Commit 11261948 authored by Biluo Shen

(WIP) OSFP

parent 4d07e48e
*.pt
*.pkl
# Xmake cache
.xmake/
......
@@ -140,8 +140,8 @@ if __name__ == "__main__":
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
......
@@ -154,7 +154,7 @@ if __name__ == "__main__":
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
......
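These call sites drop the extra positional argument when constructing the agent, tracking the signature change in the agent module at the end of this commit: the separate num_history_action_layers parameter is removed and the history-action stack reuses num_action_layers. A minimal construction sketch with the new signature, assuming the class is the PPOAgent defined in ygoai/rl/agent.py and using made-up sizes:

import torch
from ygoai.rl.agent import PPOAgent  # assumed module path

num_channels, num_layers, embedding_shape = 128, 2, 12000   # hypothetical values
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# positional args: channels, num_card_layers, num_action_layers, embedding_shape
agent = PPOAgent(num_channels, num_layers, num_layers, embedding_shape).to(device)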
@@ -5,6 +5,7 @@ from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import tyro
@@ -247,7 +248,7 @@ def main():
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
@@ -274,9 +275,9 @@ def main():
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
else:
@@ -389,7 +390,7 @@ def main():
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = traced_model(next_obs)[1].reshape(-1)
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
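Value bootstrapping (and the value recomputation in the next hunk) now goes through the predict_step helper instead of calling the traced module directly. predict_step is not shown in this file's hunks; judging from the single-model version defined in the other training script in this commit, it presumably looks roughly like this sketch (fp16_eval mirrors the autocast usage there):

import torch
from torch.cuda.amp import autocast

def predict_step(model, next_obs, fp16_eval=False):
    # Inference-only forward pass over the (traced) agent; the model is assumed
    # to return (logits, value, valid) as elsewhere in this commit.
    with torch.no_grad():
        with autocast(enabled=fp16_eval):
            logits, value, valid = model(next_obs)
    return logits, value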
@@ -403,7 +404,7 @@ def main():
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = traced_model(v_obs)[1].reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay(
@@ -420,7 +421,7 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = learns[:args.num_steps].reshape(-1)
b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
# Optimizing the policy and value network
......
@@ -243,7 +243,7 @@ def main():
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
......
@@ -102,8 +102,10 @@ class Args:
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the collect buffer; only the first `num_steps` steps are used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
@@ -161,6 +163,9 @@ def main():
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
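collect_length introduces a rollout buffer that is longer than the training window: the storage tensors below are allocated with collect_length steps, only the first num_steps are flattened into the PPO batch, and the trailing collect_length - num_steps steps are copied to the front after the update, with their values recomputed by the updated network before the next GAE pass. A toy illustration of that layout with hypothetical sizes:

import torch

num_steps, collect_length, num_envs = 4, 6, 2
carry = collect_length - num_steps                       # steps carried into the next iteration
rewards = torch.arange(collect_length * num_envs, dtype=torch.float32).reshape(collect_length, num_envs)

train_slice = rewards[:num_steps].clone()                # what the PPO update flattens and trains on
rewards[:carry] = rewards[num_steps:].clone()            # move the tail to the front after the update
# collection then resumes at step == carry rather than 0, matching
# `step = args.collect_length - args.num_steps` in the loop below.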
@@ -248,7 +253,7 @@ def main():
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent.eval()
if args.checkpoint:
@@ -260,22 +265,19 @@ def main():
if args.embedding_file:
agent.freeze_embeddings()
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
agent_t = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
def predict_step(agent: Agent, agent_t: Agent, next_obs, learn):
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
logits_t, value_t, valid = agent_t(next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
return logits, value
from ygoai.rl.ppo import train_step
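predict_step now runs a single model. During collection, the learner (traced_model) and the opponent copy (traced_model_t) are each passed through it, and their outputs are merged per environment with torch.where on the learn mask (see the rollout hunk below). A toy example of that merge pattern with made-up shapes:

import torch

num_envs, num_actions = 4, 5
learn = torch.tensor([True, False, True, False])       # which envs the training agent acts in
logits_learner = torch.zeros(num_envs, num_actions)
logits_opponent = torch.ones(num_envs, num_actions)
# rows where learn is True come from the learner, the rest from the opponent
logits = torch.where(learn[:, None], logits_learner, logits_opponent)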
@@ -289,15 +291,18 @@ def main():
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
@@ -318,6 +323,7 @@ def main():
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value = 0
step = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
@@ -329,7 +335,7 @@ def main():
model_time = 0
env_time = 0
collect_start = time.time()
for step in range(0, args.num_steps):
while step < args.collect_length:
global_step += args.num_envs
for key in obs:
@@ -339,7 +345,10 @@ def main():
learns[step] = learn
_start = time.time()
logits, value = predict_step(traced_model, traced_model_t, next_obs, learn)
logits, value = predict_step(traced_model, next_obs)
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
@@ -362,6 +371,7 @@ def main():
env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer:
continue
@@ -390,6 +400,8 @@ def main():
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
@@ -397,23 +409,36 @@ def main():
value_t = traced_model_t(next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues = torch.where(next_to_play == ai_player1, value, next_value)
if step > 0 and iteration != 1:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_self(
values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda)
returns = advantages + values
bootstrap_time = time.time() - _start
_start = time.time()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
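For reference, the standard single-policy GAE recursion that this bootstrapping builds on; bootstrap_value_self (and bootstrap_value_selfplay in the other script) presumably adapt it to the alternating self-play turns, so this generic sketch is not the repo's implementation:

import torch

num_steps, num_envs, gamma, gae_lambda = 4, 2, 0.99, 0.95   # hypothetical sizes/coefficients
rewards = torch.randn(num_steps, num_envs)
values = torch.randn(num_steps, num_envs)
dones = torch.zeros(num_steps, num_envs, dtype=torch.bool)
next_value = torch.randn(num_envs)
next_done = torch.zeros(num_envs, dtype=torch.bool)

advantages = torch.zeros_like(rewards)
lastgaelam = torch.zeros(num_envs)
for t in reversed(range(num_steps)):
    if t == num_steps - 1:
        nextnonterminal = (~next_done).float()
        nextvalues = next_value
    else:
        nextnonterminal = (~dones[t + 1]).float()
        nextvalues = values[t + 1]
    delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
    lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    advantages[t] = lastgaelam
returns = advantages + values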
@@ -437,7 +462,14 @@ def main():
if args.target_kl is not None and approx_kl > args.target_kl:
break
if step > 0:
# TODO: use a cyclic buffer to avoid copying
for v in obs.values():
v[:step] = v[args.num_steps:].clone()
for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start
if local_rank == 0:
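The TODO above points at replacing this tail copy with a cyclic buffer. A hypothetical sketch of that idea, not part of this commit: keep a logical start index and translate it when reading, so nothing has to be moved:

import torch

collect_length, num_steps, num_envs = 6, 4, 2   # hypothetical sizes
buf = torch.zeros(collect_length, num_envs)
start = 0                                        # row holding the oldest carried step

# write step i of the current iteration into row (start + i) % collect_length;
# read the training window without moving any data:
train_rows = (start + torch.arange(num_steps)) % collect_length
train_slice = buf[train_rows]
start = (start + num_steps) % collect_length     # advance past the consumed steps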
@@ -497,7 +529,7 @@ def main():
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
......
@@ -44,11 +44,9 @@ class PositionalEncoding(nn.Module):
class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
@@ -165,11 +163,17 @@ class Encoder(nn.Module):
for i in range(num_action_layers)
])
self.action_history_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_history_action_layers)
for i in range(num_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
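The history-action pathway is restructured: a new stack of TransformerEncoderLayers (history_action_net) self-attends over the action history with a key padding mask, and the TransformerDecoderLayer stack (action_history_net, now sized by num_action_layers) cross-attends from the current legal actions to that encoded history; the forward changes appear further down. A self-contained sketch of the pattern with hypothetical sizes and masks:

import torch
import torch.nn as nn

c, num_heads, num_layers = 128, 8, 2   # hypothetical sizes

history_encoder = nn.ModuleList([
    nn.TransformerEncoderLayer(c, num_heads, c * 4, dropout=0.0,
                               batch_first=True, norm_first=True)
    for _ in range(num_layers)
])
history_decoder = nn.ModuleList([
    nn.TransformerDecoderLayer(c, num_heads, c * 4, dropout=0.0,
                               batch_first=True, norm_first=True)
    for _ in range(num_layers)
])

batch, n_actions, n_history = 2, 5, 7
f_actions = torch.randn(batch, n_actions, c)     # current legal actions
f_h_actions = torch.randn(batch, n_history, c)   # embedded action history
h_mask = torch.zeros(batch, n_history, dtype=torch.bool)   # True = padded history slot
a_mask = torch.zeros(batch, n_actions, dtype=torch.bool)   # True = padded action slot

# Self-attention over the history, then cross-attention from actions to history.
for layer in history_encoder:
    f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
for layer in history_decoder:
    f_actions = layer(f_actions, f_h_actions,
                      tgt_key_padding_mask=a_mask, memory_key_padding_mask=h_mask)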
@@ -287,6 +291,7 @@ class Encoder(nn.Module):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
batch_size = x_cards.shape[0]
x_cards_1 = x_cards[:, :, :12].long()
x_cards_2 = x_cards[:, :, 12:].to(torch.float32)
@@ -294,7 +299,10 @@ class Encoder(nn.Module):
x_id = self.encode_card_id(x_cards_1[:, :, :2])
x_id = self.id_norm(x_id)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 2]))
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask[:, 0] = False
f_loc = self.loc_norm(self.loc_embed(x_loc))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3]))
x_feat1 = self.encode_card_feat1(x_cards_1)
@@ -306,11 +314,14 @@ class Encoder(nn.Module):
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
for layer in self.card_net:
# f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_na_card = self.na_card_embed.expand(batch_size, -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
# TODO: we can't use it because cudagraph complains about complex memory
# c_mask = torch.cat([torch.zeros(batch_size, 1, dtype=c_mask.dtype, device=c_mask.device), c_mask], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
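Cards in empty locations are now excluded from attention through a key padding mask, and the same convention is applied to the action history below (h_mask, msg == 0); the first column is always left unmasked, presumably so that no sample has every key masked, which would produce NaNs in softmax attention. A toy illustration of the convention (True means the position is ignored):

import torch

x_loc = torch.tensor([[3, 0, 0, 5],
                      [0, 0, 0, 0]])   # 0 = empty slot (hypothetical locations)
c_mask = x_loc == 0                     # True where there is no card
c_mask[:, 0] = False                    # keep at least one attendable position per sample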
@@ -334,21 +345,24 @@ class Encoder(nn.Module):
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.action_history_pe(f_h_actions)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = layer(
f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask[:, 0] = False
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.history_action_pe(f_h_actions)
for layer in self.history_action_net:
f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
for layer in self.action_history_net:
f_actions = layer(
f_actions, f_h_actions, tgt_key_padding_mask=mask, memory_key_padding_mask=h_mask)
f_actions = self.action_norm(f_actions)
@@ -385,13 +399,12 @@ class Actor(nn.Module):
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False,
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True):
super(PPOAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
channels, num_card_layers, num_action_layers, embedding_shape, bias, affine)
c = channels
self.actor = Actor(c, a_trans)
......
@@ -11,8 +11,7 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
if not args.learn_opponent:
valid = torch.logical_and(valid, mb_learns)
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
......
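The learn_opponent special case is gone from train_step: the loss mask is now always valid & mb_learns, and the toggle is handled where b_learns is built in the training scripts (all ones when learn_opponent is enabled). As a reference for how such a mask typically enters the clipped PPO objective via a masked mean, a toy sketch with made-up tensors and an assumed clip coefficient (not the repo's exact train_step):

import torch
from torch.distributions import Categorical

def masked_mean(x, mask):
    # Average only over samples the learning agent actually owns.
    return (x * mask.float()).sum() / mask.float().sum().clamp(min=1)

# Hypothetical stand-ins for one minibatch.
logits = torch.randn(8, 6)
mb_actions = torch.randint(0, 6, (8,))
mb_logprobs = torch.randn(8)
mb_advantages = torch.randn(8)
valid = torch.ones(8, dtype=torch.bool)
mb_learns = torch.tensor([True, True, False, True, False, True, True, False])
clip_coef = 0.2  # assumed PPO clip coefficient

probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
valid = torch.logical_and(valid, mb_learns)
ratio = (newlogprob - mb_logprobs).exp()
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), valid)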
@@ -1870,10 +1870,10 @@ private:
std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
SpecIndex spec2index;
std::vector<int> loc_n_cards;
int offset = 0;
for (auto pi = 0; pi < 2; pi++) {
const PlayerId player = (to_play + pi) % 2;
const bool opponent = pi == 1;
int offset = opponent ? spec_.config["max_cards"_] : 0;
std::vector<std::pair<uint8_t, bool>> configs = {
{LOCATION_DECK, true}, {LOCATION_HAND, true},
{LOCATION_MZONE, false}, {LOCATION_SZONE, false},
......