Commit e7a19464 authored by biluo.shen

Improve for compile

parent eba0e134
@@ -14,7 +14,7 @@ import tyro
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
-from ygoai.rl.agent import Agent
+from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.buffer import create_obs
@@ -171,8 +171,13 @@ if __name__ == "__main__":
             _start = time.time()
             obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
             with torch.no_grad():
-                values = agent(obs)[0]
-                actions = torch.argmax(values, dim=1).cpu().numpy()
+                logits, values = agent(obs)
+                probs = torch.softmax(logits, dim=-1)
+            probs = probs.cpu().numpy()
+            if args.play:
+                print(probs[probs != 0].tolist())
+                print(values)
+            actions = probs.argmax(axis=1)
             model_time += time.time() - _start
         else:
             if args.strategy == "random":
...
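The evaluation loop above now picks the greedy action from the policy head rather than taking an argmax over the value output. A minimal sketch of that pattern, assuming an agent whose forward pass returns (logits, values) with illegal options masked to -inf so their softmax probability is zero; the greedy_actions helper name and the dict-of-arrays observation are illustrative, not part of the repo:

import numpy as np
import torch

def greedy_actions(agent, obs, device, play=False):
    # Move the batched observation arrays onto the model device.
    obs = {k: torch.as_tensor(v, device=device) for k, v in obs.items()}
    with torch.no_grad():
        logits, values = agent(obs)            # forward pass returns (logits, values)
        probs = torch.softmax(logits, dim=-1)  # masked options have -inf logits -> prob 0
    probs = probs.cpu().numpy()
    if play:
        # In interactive play, show the non-zero option probabilities and the value estimate.
        print(probs[probs != 0].tolist())
        print(values)
    return probs.argmax(axis=1)                # greedy action per environment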
@@ -13,7 +13,9 @@ import tyro
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from torch.distributions import Categorical
 import torch.distributed as dist
+from torch.cuda.amp import GradScaler, autocast

 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
@@ -44,7 +46,7 @@ class Args:
     """the deck file for the second player"""
     code_list_file: str = "code_list.txt"
     """the code list file for card embeddings"""
-    embedding_file: str = "embeddings_en.npy"
+    embedding_file: Optional[str] = "embeddings_en.npy"
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
@@ -101,6 +103,10 @@ class Args:
     """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
     env_threads: Optional[int] = None
     """the number of threads to use for envpool, defaults to `num_envs`"""
+    fp16_train: bool = False
+    """if toggled, training will be done in fp16 precision"""
+    fp16_eval: bool = False
+    """if toggled, evaluation will be done in fp16 precision"""
     tb_dir: str = "./runs"
     """tensorboard log directory"""
@@ -199,20 +205,28 @@ def run(local_rank, world_size):
     envs = RecordEpisodeStatistics(envs)

-    embeddings = np.load(args.embedding_file)
+    if args.embedding_file:
+        embeddings = np.load(args.embedding_file)
+        embedding_shape = embeddings.shape
+    else:
+        embedding_shape = None
     L = args.num_layers
-    agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
-    agent.load_embeddings(embeddings)
-    if args.compile:
-        agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
+    agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+    if args.embedding_file:
+        agent.load_embeddings(embeddings)
+    # if args.compile:
+    #     agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)

     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
+    scaler = GradScaler(enabled=args.fp16_train)

     def masked_mean(x, valid):
         x = x.masked_fill(~valid, 0)
         return x.sum() / valid.float().sum()

-    def train_step(agent, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
-        _, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
+    def train_step(agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
+        with autocast(enabled=args.fp16_train):
+            _, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
         logratio = newlogprob - mb_logprobs
         ratio = logratio.exp()
@@ -251,12 +265,20 @@ def run(local_rank, world_size):
         entropy_loss = masked_mean(entropy, valid)
         loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

         optimizer.zero_grad()
-        loss.backward()
+        scaler.scale(loss).backward()
+        scaler.unscale_(optimizer)
         reduce_gradidents(agent, args.world_size)
         return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss

+    def predict_step(agent, next_obs):
+        with torch.no_grad():
+            with autocast(enabled=args.fp16_eval):
+                logits, values = agent(next_obs)
+        return logits, values
+
     if args.compile:
         train_step = torch.compile(train_step, mode=args.compile_mode)
+        predict_step = torch.compile(predict_step, mode=args.compile_mode)

     def to_tensor(x, dtype=torch.float32):
         return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
@@ -296,8 +318,10 @@ def run(local_rank, world_size):
             dones[step] = next_done

             _start = time.time()
-            with torch.no_grad():
-                action, logprob, _, value, valid = agent.get_action_and_value(next_obs)
+            logits, value = predict_step(agent, next_obs)
+            probs = Categorical(logits=logits)
+            action = probs.sample()
+            logprob = probs.log_prob(action)
             values[step] = value.flatten()
             actions[step] = action
@@ -374,10 +398,11 @@ def run(local_rank, world_size):
                     k: v[mb_inds] for k, v in b_obs.items()
                 }
                 old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
-                    train_step(agent, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
+                    train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
                                b_returns[mb_inds], b_values[mb_inds])
                 nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
-                optimizer.step()
+                scaler.step(optimizer)
+                scaler.update()
                 clipfracs.append(clipfrac.item())

                 if args.target_kl is not None and approx_kl > args.target_kl:
...
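The new fp16 path in train_step and predict_step follows the standard torch.cuda.amp recipe: run the forward pass and loss under autocast, scale the loss before backward, unscale before anything that reads raw gradients (the gradient all-reduce and clipping), then let the scaler drive the optimizer step. A minimal self-contained sketch of that recipe, with placeholder names (amp_update, loss_fn, model are illustrative, not the repo's code):

import torch
from torch.cuda.amp import GradScaler, autocast

def amp_update(model, optimizer, scaler, loss_fn, batch, max_grad_norm=0.5, fp16=True):
    # Forward pass and loss in reduced precision when fp16 is enabled.
    with autocast(enabled=fp16):
        loss = loss_fn(model, batch)
    optimizer.zero_grad()
    # Scale the loss so small fp16 gradients do not underflow in backward.
    scaler.scale(loss).backward()
    # Unscale before any operation that inspects gradients (clipping, all-reduce).
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # step() skips the update if gradients overflowed; update() adjusts the scale factor.
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()

Because GradScaler(enabled=False) and autocast(enabled=False) are no-ops, the same code path serves full-precision training when fp16_train is left off.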
...@@ -45,7 +45,7 @@ class Encoder(nn.Module): ...@@ -45,7 +45,7 @@ class Encoder(nn.Module):
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False) self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
if embedding_shape is None: if embedding_shape is None:
n_embed, embed_dim = 150, 1024 n_embed, embed_dim = 1000, 1024
else: else:
n_embed, embed_dim = embedding_shape n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown n_embed = 1 + n_embed # 1 (index 0) for unknown
...@@ -339,19 +339,28 @@ class PPOAgent(nn.Module): ...@@ -339,19 +339,28 @@ class PPOAgent(nn.Module):
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1) f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
return self.critic(f) return self.critic(f)
def get_action_and_value(self, x, action=None): def get_action_and_value(self, x, action):
f_actions, mask, valid = self.encoder(x) f_actions, mask, valid = self.encoder(x)
c_mask = 1 - mask.unsqueeze(-1).float() c_mask = 1 - mask.unsqueeze(-1).float()
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1) f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
logits = self.actor(f_actions)[..., 0] logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf")) logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
def forward(self, x):
f_actions, mask, valid = self.encoder(x)
c_mask = 1 - mask.unsqueeze(-1).float()
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits, self.critic(f)
class DMCAgent(nn.Module): class DMCAgent(nn.Module):
......
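PPOAgent.forward now returns only plain tensors (masked logits and the critic value), and sampling with Categorical is moved to the caller. Keeping torch.distributions objects out of the traced function is what makes predict_step safe to pass to torch.compile, and the added logits.float() keeps the -inf masking and the later softmax in fp32 under autocast. A minimal sketch of the same split, assuming a toy policy (TinyPolicy is a stand-in, not the repo's encoder/actor):

import torch
import torch.nn as nn
from torch.distributions import Categorical

class TinyPolicy(nn.Module):
    def __init__(self, obs_dim=8, n_actions=4):
        super().__init__()
        self.body = nn.Linear(obs_dim, 64)
        self.actor = nn.Linear(64, n_actions)
        self.critic = nn.Linear(64, 1)

    def forward(self, x):
        h = torch.relu(self.body(x))
        # Return plain tensors only; distribution objects stay outside the compiled region.
        return self.actor(h), self.critic(h)

policy = TinyPolicy()
predict = torch.compile(policy)        # compile the tensor-only forward pass
obs = torch.randn(16, 8)
logits, value = predict(obs)
dist = Categorical(logits=logits)      # sample outside the compiled graph
action = dist.sample()
logprob = dist.log_prob(action)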
@@ -39,6 +39,7 @@ def mp_start(run):
     if world_size == 1:
         run(local_rank=0, world_size=world_size)
     else:
+        mp.set_start_method('spawn')
         children = []
         for i in range(world_size):
             subproc = mp.Process(target=run, args=(i, world_size))
...
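mp.set_start_method('spawn') has to be called before any worker process is created: with the default fork start method, children inherit the parent's CUDA state and typically fail as soon as they touch the GPU. A minimal sketch of the pattern (the worker function is illustrative):

import torch
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Each spawned process can safely initialize its own CUDA context.
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    print(rank, world_size, device)

if __name__ == "__main__":
    world_size = 2
    mp.set_start_method("spawn")   # must run before the processes are created
    procs = [mp.Process(target=worker, args=(i, world_size)) for i in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()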