Commit ad9980e5 authored by biluo.shen

Add battle

parent e5e5402a
edopro-core @ 8c623744
Subproject commit 8c6237444e294b730bce1eccc6fab2721b7cbea9
......@@ -65,7 +65,7 @@ class Args:
checkpoint2: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent"""
compile: bool = True
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
......@@ -130,33 +130,37 @@ if __name__ == "__main__":
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
predict_step = torch.compile(predict_step, mode='reduce-overhead')
if args.checkpoint1.endswith(".ptj"):
agent1 = torch.jit.load(args.checkpoint1)
agent2 = torch.jit.load(args.checkpoint2)
else:
if args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
def optimize_for_inference(agent):
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
return torch.jit.optimize_for_inference(traced_model)
agent1 = optimize_for_inference(agent1)
agent2 = optimize_for_inference(agent2)
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
predict_step = torch.compile(predict_step, mode='reduce-overhead')
else:
if args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
def optimize_for_inference(agent):
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
return torch.jit.optimize_for_inference(traced_model)
agent1 = optimize_for_inference(agent1)
agent2 = optimize_for_inference(agent2)
obs, infos = envs.reset()
next_to_play_ = infos['to_play']
......
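For reference, a minimal sketch (not part of the repository) of the checkpoint conventions visible in this hunk: TorchScript models carry a trailing "j" (.ptj), and state dicts saved from a torch.compile-wrapped model prefix their keys with "_orig_mod.", which has to be stripped before loading into a plain module. The helper name load_agent_checkpoint is ours.
import torch
def load_agent_checkpoint(agent, ckpt, device, compiled=False):
    # TorchScript checkpoints (the ".ptj" files written by the --convert path) load whole.
    if ckpt.endswith(".ptj"):
        return torch.jit.load(ckpt, map_location=device)
    state_dict = torch.load(ckpt, map_location=device)
    if not compiled:
        # torch.compile wraps the module and prefixes parameter keys with "_orig_mod.";
        # strip it so the plain, un-compiled module accepts the weights.
        prefix = "_orig_mod."
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in state_dict.items()}
    agent.load_state_dict(state_dict)
    return agent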
......@@ -80,6 +80,9 @@ class Args:
"""if toggled, the model will be compiled"""
optimize: bool = True
"""if toggled, the model will be optimized"""
convert: bool = False
"""if toggled, the model will be converted to a jit model and the program will exit"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
......@@ -156,6 +159,21 @@ if __name__ == "__main__":
print(agent.load_state_dict(state_dict))
if args.compile:
if args.convert:
# Not supported yet: dynamic shapes are not handled and inference is very slow
raise NotImplementedError
# obs = create_obs(envs.observation_space, (num_envs,), device=device)
# dynamic_shapes = {"x": {}}
# # batch_dim = torch.export.Dim("batch", min=1, max=64)
# batch_dim = None
# for k, v in obs.items():
# dynamic_shapes["x"][k] = {0: batch_dim}
# program = torch.export.export(
# agent, (obs,),
# dynamic_shapes=dynamic_shapes,
# )
# torch.export.save(program, args.checkpoint + "2")
# exit(0)
agent = torch.compile(agent, mode='reduce-overhead')
elif args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
......@@ -164,6 +182,10 @@ if __name__ == "__main__":
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
return torch.jit.optimize_for_inference(traced_model)
agent = optimize_for_inference(agent)
if args.convert:
torch.jit.save(agent, args.checkpoint + "j")
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
obs, infos = envs.reset()
next_to_play = infos['to_play']
......
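For context, the --convert branch above (trace on a representative batch, fold into an inference-optimized TorchScript module, save with a trailing "j") can be read as the standalone routine below; the function and argument names are ours, a sketch rather than the repository's API.
import torch
def convert_to_torchscript(agent, example_obs, checkpoint_path):
    # Trace the eager model, optimize the graph for inference, and save it next to
    # the original checkpoint (e.g. agent.pt -> agent.ptj).
    with torch.no_grad():
        traced = torch.jit.trace(agent, (example_obs,),
                                 check_tolerance=False, check_trace=False)
        optimized = torch.jit.optimize_for_inference(traced)
    torch.jit.save(optimized, checkpoint_path + "j")
    return optimized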
......@@ -425,7 +425,7 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0:
if iteration % args.save_interval == 0 or iteration == args.num_iterations:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
......
......@@ -21,7 +21,7 @@ from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, mp_start, setup, fprint
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
......@@ -118,8 +118,6 @@ class Args:
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
port: int = 12356
"""the port to use for distributed training"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
......@@ -140,7 +138,12 @@ class Args:
"""the number of processes (computed in runtime)"""
def run(local_rank, world_size):
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
......@@ -158,12 +161,12 @@ def run(local_rank, world_size):
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
setup(args.backend, local_rank, args.world_size, args.port)
torchrun_setup(args.backend, local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if local_rank == 0:
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
......@@ -177,10 +180,10 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += local_rank
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
......@@ -188,7 +191,7 @@ def run(local_rank, world_size):
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro("english", args.deck, args.code_list_file)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
......@@ -429,7 +432,8 @@ def run(local_rank, world_size):
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
fprint(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done
......@@ -561,16 +565,17 @@ def run(local_rank, world_size):
train_time = time.time() - _start
fprint(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0:
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
......@@ -581,15 +586,17 @@ def run(local_rank, world_size):
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
......@@ -628,11 +635,12 @@ def run(local_rank, world_size):
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
if local_rank == 0:
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
......@@ -641,10 +649,10 @@ def run(local_rank, world_size):
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if local_rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
if __name__ == "__main__":
mp_start(run)
main()
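With mp_start/setup replaced by torchrun_setup, the script is now meant to be launched by torchrun, which provides RANK, LOCAL_RANK and WORLD_SIZE in the environment (read at the top of main()). Below is a plausible launch command and a sketch of what torchrun_setup presumably wraps; the script filename and the body of torchrun_setup_sketch are assumptions, not taken from ygoai.rl.dist.
# Hypothetical launch (filename assumed):
#   torchrun --nproc_per_node=4 ppo.py --num-envs 64 --compile reduce-overhead
import torch
import torch.distributed as dist
def torchrun_setup_sketch(backend, local_rank):
    # Under torchrun, MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are already exported,
    # so the process group can be initialized from the environment.
    dist.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)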
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import optree
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 500
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file:
embeddings = np.load(args.embedding_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
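# Illustration: masked_mean(torch.tensor([1., 2., 10.]), torch.tensor([True, True, False]))
# zero-fills the invalid 10. and divides by the number of valid entries: (1 + 2) / 2 = 1.5.
# masked_normalize applies the same zero-fill before computing the mean/std it uses
# to normalize advantages.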
def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
return logits, value
if args.compile:
# It seems that calling torch.compile twice causes a segfault at startup, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, dtype=next_to_play.dtype)
next_value1 = 0
next_value2 = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
model_time = 0
env_time = 0
collect_start = time.time()
agent.eval()
for step in range(0, args.num_steps):
global_step += args.num_envs
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
learn = next_to_play == ai_player1
learns[step] = learn
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
env_time += time.time() - _start
rewards[step] = to_tensor(reward)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool)
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
# value = agent.get_value(next_obs).reshape(-1)
value = traced_model(next_obs)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(args.num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# delta1 = reward1 + args.gamma * nextvalues1 - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# advantages[t] = lastgaelam1_
# nextvalues1 = values[t]
# lastgaelam1 = lastgaelam_
# else:
# if dones[t+1]:
# reward2 = rewards[t]
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
#
# reward1 = -rewards[t]
# done_used1 = False
# else:
# if not done_used2:
# reward2 = reward2
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
# else:
# reward2 = rewards[t]
# reward1 = reward1
# delta2 = reward2 + args.gamma * nextvalues2 - values[t]
# lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
# advantages[t] = lastgaelam2_
# nextvalues2 = values[t]
# lastgaelam2 = lastgaelam_
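# The vectorized code below reproduces the branching sketched in the comments above
# using torch.where: at each step t it routes the reward, bootstrap value and GAE
# accumulator to whichever player acted (learns[t]) and to the opponent, and resets
# the per-player state at episode boundaries (next_done), so each player's advantage
# is computed against that player's own next value.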
learn1 = learns[t]
learn2 = ~learn1
if t != args.num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + args.gamma * nextvalues1 - values[t]
delta2 = reward2 + args.gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
returns = advantages + values
bootstrap_time = time.time() - _start
_start = time.time()
agent.train()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
train_time = time.time() - _start
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
episode_lengths = []
episode_rewards = []
eval_win_rates = []
e_obs = eval_envs.reset()[0]
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_logits = predict_step(traced_model, e_obs)[0]
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= local_eval_episodes:
break
eval_return = np.mean(episode_rewards[:local_eval_episodes])
eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
# Eval with old model
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
writer.close()
if __name__ == "__main__":
main()
......@@ -530,7 +530,7 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
......@@ -564,7 +564,7 @@ def run(local_rank, world_size):
agent2.load_state_dict(agent1.state_dict())
version += 1
if local_rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pth"))
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
......@@ -614,7 +614,7 @@ def run(local_rank, world_size):
dist.destroy_process_group()
envs.close()
if local_rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
......
......@@ -7,7 +7,7 @@ def bytes_to_bin(x, points, intervals):
x = x[..., 0] * 256 + x[..., 1]
x = x.unsqueeze(-1)
return torch.clamp((x - points + intervals) / intervals, 0, 1)
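# bytes_to_bin reads a 16-bit value stored as two uint8 bytes (high * 256 + low) and
# produces a soft thermometer encoding over the bin points: bins whose point is at or
# below the value saturate at 1, the bin just above gets a fractional activation, and
# farther bins are 0. For illustration only (not the actual make_bin_params output),
# with points [1000, 2000, 3000, 4000] and interval 1000, a value of 2500 encodes to
# [1.0, 1.0, 0.5, 0.0].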
def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
x_max1 = 8000
......@@ -334,6 +334,308 @@ class Encoder(nn.Module):
f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
return f_actions, f_state, mask, valid
class Encoder1(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
super().__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
self.loc_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.seq_embed = nn.Embedding(76, c)
self.seq_norm = nn.LayerNorm(c, elementwise_affine=affine)
linear = lambda in_features, out_features: nn.Linear(in_features, out_features, bias=bias)
c_num = c // 8
n_bins = 32
self.num_fc = nn.Sequential(
linear(n_bins, c_num),
nn.ReLU(),
)
bin_points, bin_intervals = make_bin_params(n_bins=n_bins)
self.bin_points = nn.Parameter(bin_points, requires_grad=False)
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
if embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(embedding_shape, int):
n_embed, embed_dim = embedding_shape, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
self.id_embed = nn.Embedding(n_embed, embed_dim)
self.id_fc_emb = linear(1024, c // 4)
self.id_norm = nn.LayerNorm(c // 4, elementwise_affine=False)
self.owner_embed = nn.Embedding(2, c // 16)
self.position_embed = nn.Embedding(9, c // 16 * 2)
self.overley_embed = nn.Embedding(2, c // 16)
self.attribute_embed = nn.Embedding(8, c // 16)
self.race_embed = nn.Embedding(27, c // 16)
self.level_embed = nn.Embedding(14, c // 16)
self.counter_embed = nn.Embedding(16, c // 16)
self.type_fc_emb = linear(25, c // 16 * 2)
self.atk_fc_emb = linear(c_num, c // 16)
self.def_fc_emb = linear(c_num, c // 16)
self.feat_norm = nn.LayerNorm(c // 4 * 3, elementwise_affine=affine)
self.na_card_embed = nn.Parameter(torch.randn(1, c) * 0.02, requires_grad=True)
num_heads = max(2, c // 128)
self.card_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_card_layers)
])
self.card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.lp_fc_emb = linear(c_num, c // 4)
self.oppo_lp_fc_emb = linear(c_num, c // 4)
self.turn_embed = nn.Embedding(20, c // 8)
self.phase_embed = nn.Embedding(11, c // 8)
self.if_first_embed = nn.Embedding(2, c // 8)
self.is_my_turn_embed = nn.Embedding(2, c // 8)
self.global_norm_pre = nn.LayerNorm(c, elementwise_affine=affine)
self.global_net = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.global_norm = nn.LayerNorm(c, elementwise_affine=False)
divisor = 8
self.a_msg_embed = nn.Embedding(30, c // divisor)
self.a_act_embed = nn.Embedding(13, c // divisor)
self.a_yesno_embed = nn.Embedding(3, c // divisor)
self.a_phase_embed = nn.Embedding(4, c // divisor)
self.a_cancel_finish_embed = nn.Embedding(3, c // divisor)
self.a_position_embed = nn.Embedding(9, c // divisor)
self.a_option_embed = nn.Embedding(6, c // divisor // 2)
self.a_number_embed = nn.Embedding(13, c // divisor // 2)
self.a_place_embed = nn.Embedding(31, c // divisor // 2)
# TODO: maybe same embedding as attribute_embed
self.a_attrib_embed = nn.Embedding(10, c // divisor // 2)
self.a_feat_norm = nn.LayerNorm(c, elementwise_affine=affine)
self.a_card_norm = nn.LayerNorm(c, elementwise_affine=False)
self.a_card_proj = nn.Sequential(
nn.Linear(c, c),
nn.ReLU(),
nn.Linear(c, c),
)
self.h_id_fc_emb = linear(1024, c)
self.h_id_norm = nn.LayerNorm(c, elementwise_affine=False)
self.h_a_feat_norm = nn.LayerNorm(c, elementwise_affine=False)
num_heads = max(2, c // 128)
self.action_card_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
self.init_embeddings()
def init_embeddings(self, scale=0.0001):
for n, m in self.named_modules():
if isinstance(m, nn.Embedding):
nn.init.uniform_(m.weight, -scale, scale)
elif n in ["atk_fc_emb", "def_fc_emb"]:
nn.init.uniform_(m.weight, -scale * 10, scale * 10)
elif n in ["lp_fc_emb", "oppo_lp_fc_emb"]:
nn.init.uniform_(m.weight, -scale, scale)
elif "fc_emb" in n:
nn.init.uniform_(m.weight, -scale, scale)
def load_embeddings(self, embeddings):
weight = self.id_embed.weight
embeddings = torch.from_numpy(embeddings).to(dtype=weight.dtype, device=weight.device)
unknown_embed = embeddings.mean(dim=0, keepdim=True)
embeddings = torch.cat([unknown_embed, embeddings], dim=0)
weight.data.copy_(embeddings)
def freeze_embeddings(self):
self.id_embed.weight.requires_grad = False
def num_transform(self, x):
return self.num_fc(bytes_to_bin(x, self.bin_points, self.bin_intervals))
def encode_action_(self, x):
x_a_msg = self.a_msg_embed(x[:, :, 0])
x_a_act = self.a_act_embed(x[:, :, 1])
x_a_yesno = self.a_yesno_embed(x[:, :, 2])
x_a_phase = self.a_phase_embed(x[:, :, 3])
x_a_cancel = self.a_cancel_finish_embed(x[:, :, 4])
x_a_position = self.a_position_embed(x[:, :, 5])
x_a_option = self.a_option_embed(x[:, :, 6])
x_a_number = self.a_number_embed(x[:, :, 7])
x_a_place = self.a_place_embed(x[:, :, 8])
x_a_attrib = self.a_attrib_embed(x[:, :, 9])
return x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib
def get_action_card_(self, x, f_cards):
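# Each action row stores up to max_multi_select referenced cards as two-byte "spec"
# indices into the card sequence. The bytes are recombined (high * 256 + low), the
# corresponding rows of f_cards are gathered, and the result is mean-pooled over the
# non-zero entries (index 0 is always kept so an action with no cards stays valid).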
b, n, c = x.shape
m = c // 2
spec_index = x.view(b, n, m, 2)
spec_index = spec_index[..., 0] * 256 + spec_index[..., 1]
mask = spec_index != 0
mask[:, :, 0] = True
spec_index = spec_index.view(b, -1)
B = torch.arange(b, device=spec_index.device)
f_a_actions = f_cards[B[:, None], spec_index]
f_a_actions = f_a_actions.view(b, n, m, -1)
f_a_actions = (f_a_actions * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return f_a_actions
def get_h_action_card_(self, x):
b, n, _ = x.shape
x_ids = x.view(b, n, -1, 2)
x_ids = x_ids[..., 0] * 256 + x_ids[..., 1]
mask = x_ids != 0
mask[:, :, 0] = True
x_ids = self.id_embed(x_ids)
x_ids = self.h_id_fc_emb(x_ids)
x_ids = (x_ids * mask.unsqueeze(-1)).sum(dim=2) / mask.sum(dim=2, keepdim=True)
return x_ids
def encode_card_id(self, x):
x_id = self.id_embed(x)
x_id = self.id_fc_emb(x_id)
x_id = self.id_norm(x_id)
return x_id
def encode_card_feat1(self, x1):
x_owner = self.owner_embed(x1[:, :, 2])
x_position = self.position_embed(x1[:, :, 3])
x_overley = self.overley_embed(x1[:, :, 4])
x_attribute = self.attribute_embed(x1[:, :, 5])
x_race = self.race_embed(x1[:, :, 6])
x_level = self.level_embed(x1[:, :, 7])
x_counter = self.counter_embed(x1[:, :, 8])
return x_owner, x_position, x_overley, x_attribute, x_race, x_level, x_counter
def encode_card_feat2(self, x2):
x_atk = self.num_transform(x2[:, :, 0:2])
x_atk = self.atk_fc_emb(x_atk)
x_def = self.num_transform(x2[:, :, 2:4])
x_def = self.def_fc_emb(x_def)
x_type = self.type_fc_emb(x2[:, :, 4:])
return x_atk, x_def, x_type
def encode_global(self, x):
x_global_1 = x[:, :4].float()
x_g_lp = self.lp_fc_emb(self.num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))
x_global_2 = x[:, 4:-1].long()
x_g_turn = self.turn_embed(x_global_2[:, 0])
x_g_phase = self.phase_embed(x_global_2[:, 1])
x_g_if_first = self.if_first_embed(x_global_2[:, 2])
x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 3])
x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
return x_global
def forward(self, x):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_card_ids = x_cards[:, :, :2].long()
x_card_ids = x_card_ids[..., 0] * 256 + x_card_ids[..., 1]
x_cards_1 = x_cards[:, :, 2:11].long()
x_cards_2 = x_cards[:, :, 11:].to(torch.float32)
x_id = self.encode_card_id(x_card_ids)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 0]))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 1]))
x_feat1 = self.encode_card_feat1(x_cards_1)
x_feat2 = self.encode_card_feat2(x_cards_2)
x_feat = torch.cat([*x_feat1, *x_feat2], dim=-1)
x_feat = self.feat_norm(x_feat)
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
x_global = self.global_norm_pre(x_global)
f_global = x_global + self.global_net(x_global)
f_global = self.global_norm(f_global)
f_cards = f_cards + f_global.unsqueeze(1)
x_actions = x_actions.long()
max_multi_select = (x_actions.shape[-1] - 9) // 2
mo = max_multi_select * 2
f_a_cards = self.get_action_card_(x_actions[..., :mo], f_cards)
f_a_cards = f_a_cards + self.a_card_proj(self.a_card_norm(f_a_cards))
x_a_feats = self.encode_action_(x_actions[..., mo:])
x_a_feats = torch.cat(x_a_feats, dim=-1)
f_actions = f_a_cards + self.a_feat_norm(x_a_feats)
mask = x_actions[:, :, mo] == 0 # msg == 0
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :mo])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, mo:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = self.action_norm(f_actions)
f_s_cards_global = f_cards.mean(dim=1)
c_mask = 1 - mask.unsqueeze(-1).float()
f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
return f_actions, f_state, mask, valid
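# A minimal instantiation sketch for Encoder1; the numbers mirror the Args defaults
# elsewhere in this commit and are illustrative, not prescriptive:
#   encoder = Encoder1(channels=128, num_card_layers=2, num_action_layers=2,
#                      num_history_action_layers=2, embedding_shape=999)
#   f_actions, f_state, mask, valid = encoder(obs)
# where obs is a dict with 'cards_', 'global_', 'actions_' and 'h_actions_' tensors.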
class Actor(nn.Module):
def __init__(self, channels, use_transformer=False):
......