Commit e7a19464 authored by biluo.shen's avatar biluo.shen

Improve for compile

parent eba0e134
......@@ -14,7 +14,7 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import Agent
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
......@@ -171,8 +171,13 @@ if __name__ == "__main__":
_start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
values = agent(obs)[0]
actions = torch.argmax(values, dim=1).cpu().numpy()
logits, values = agent(obs)
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
if args.play:
print(probs[probs != 0].tolist())
print(values)
actions = probs.argmax(axis=1)
model_time += time.time() - _start
else:
if args.strategy == "random":
......
......@@ -13,7 +13,9 @@ import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
......@@ -44,7 +46,7 @@ class Args:
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: str = "embeddings_en.npy"
embedding_file: Optional[str] = "embeddings_en.npy"
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
......@@ -101,6 +103,10 @@ class Args:
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
......@@ -199,21 +205,29 @@ def run(local_rank, world_size):
envs = RecordEpisodeStatistics(envs)
embeddings = np.load(args.embedding_file)
if args.embedding_file:
embeddings = np.load(args.embedding_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
agent.load_embeddings(embeddings)
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
if args.embedding_file:
agent.load_embeddings(embeddings)
if args.compile:
agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
# if args.compile:
# agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train)
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def train_step(agent, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
_, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
def train_step(agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
with autocast(enabled=args.fp16_train):
_, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
......@@ -251,12 +265,20 @@ def run(local_rank, world_size):
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
loss.backward()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
reduce_gradidents(agent, args.world_size)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, values = agent(next_obs)
return logits, values
if args.compile:
train_step = torch.compile(train_step, mode=args.compile_mode)
predict_step = torch.compile(predict_step, mode=args.compile_mode)
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
......@@ -296,8 +318,10 @@ def run(local_rank, world_size):
dones[step] = next_done
_start = time.time()
with torch.no_grad():
action, logprob, _, value, valid = agent.get_action_and_value(next_obs)
logits, value = predict_step(agent, next_obs)
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value.flatten()
actions[step] = action
......@@ -374,10 +398,11 @@ def run(local_rank, world_size):
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds])
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
......
......@@ -45,7 +45,7 @@ class Encoder(nn.Module):
self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)
if embedding_shape is None:
n_embed, embed_dim = 150, 1024
n_embed, embed_dim = 1000, 1024
else:
n_embed, embed_dim = embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
......@@ -339,20 +339,29 @@ class PPOAgent(nn.Module):
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
return self.critic(f)
def get_action_and_value(self, x, action=None):
def get_action_and_value(self, x, action):
f_actions, mask, valid = self.encoder(x)
c_mask = 1 - mask.unsqueeze(-1).float()
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
def forward(self, x):
f_actions, mask, valid = self.encoder(x)
c_mask = 1 - mask.unsqueeze(-1).float()
f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits, self.critic(f)
class DMCAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
......
......@@ -39,6 +39,7 @@ def mp_start(run):
if world_size == 1:
run(local_rank=0, world_size=world_size)
else:
mp.set_start_method('spawn')
children = []
for i in range(world_size):
subproc = mp.Process(target=run, args=(i, world_size))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment