Commit 385bd1cb authored by biluo.shen

Add PPO selfplay

parent af60d012
@@ -147,21 +147,21 @@ if __name__ == "__main__":
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = agent.eval()
# agent = agent.eval()
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
else:
state_dict = None
if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
if state_dict:
agent.load_state_dict(state_dict)
print(agent.load_state_dict(state_dict))
agent = torch.compile(agent, mode='reduce-overhead')
else:
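# note: checkpoints saved from a torch.compile'd model prefix parameter keys with
# "_orig_mod."; strip it when loading into an uncompiled model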
prefix = "_orig_mod."
if state_dict:
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
agent.load_state_dict(state_dict)
print(agent.load_state_dict(state_dict))
if args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
@@ -170,6 +170,7 @@ if __name__ == "__main__":
agent = torch.jit.optimize_for_inference(traced_model)
obs, infos = envs.reset()
next_to_play = infos['to_play']
episode_rewards = []
episode_lengths = []
@@ -191,7 +192,7 @@ if __name__ == "__main__":
_start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits, values = agent(obs)
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
if args.play:
@@ -212,9 +213,11 @@ if __name__ == "__main__":
# print(k, v.tolist())
# print(infos)
# print(actions[0])
to_play = next_to_play
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
next_to_play = infos['to_play']
env_time += time.time() - _start
step += 1
@@ -225,7 +228,7 @@ if __name__ == "__main__":
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
if args.selfplay:
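# to_play was recorded before the final step, so it identifies who made the game-ending move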
pl = 1 if infos['to_play'][idx] == 0 else -1
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
......
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
@@ -52,7 +53,7 @@ class Args:
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
play_mode: str = "self+bot"
play_mode: str = "bot"
"""the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""
num_layers: int = 2
@@ -74,6 +75,12 @@ class Args:
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.6
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
@@ -95,10 +102,8 @@ class Args:
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: bool = True
"""whether to use torch.compile to compile the model and functions"""
compile_mode: Optional[str] = None
"""the mode to use for torch.compile"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
@@ -118,6 +123,8 @@ class Args:
"""the probability of logging"""
port: int = 12356
"""the port to use for distributed training"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
@@ -197,7 +204,7 @@ def run(local_rank, world_size):
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
@@ -205,7 +212,25 @@ def run(local_rank, world_size):
if local_rank == 0:
print(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
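# split evaluation episodes evenly across ranks, with one eval env per episode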
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file:
embeddings = np.load(args.embedding_file)
@@ -213,11 +238,14 @@ def run(local_rank, world_size):
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
if args.embedding_file:
agent.load_embeddings(embeddings)
agent1.load_embeddings(embeddings)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2.load_state_dict(agent1.state_dict())
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
optim_params = list(agent1.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
@@ -225,9 +253,21 @@ def run(local_rank, world_size):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def train_step(agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values):
def masked_normalize(x, valid, eps=1e-8):
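# normalize x using mean/std computed over the valid entries only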
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).masked_fill(~valid, 0).sum() / n  # exclude invalid entries from the variance too
std = (var + eps).sqrt()
return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train):
_, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long())
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
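# restrict the loss to steps where the learning agent (agent1) actually acted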
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
@@ -238,7 +278,7 @@ def run(local_rank, world_size):
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
@@ -269,15 +309,25 @@ def run(local_rank, world_size):
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent, next_obs):
def predict_step(agent1: Agent, agent2: Agent, next_obs, learn):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, values = agent(next_obs)
return logits, values
logits1, value1, valid = agent1(next_obs)
logits2, value2, valid = agent2(next_obs)
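# per-env routing: take agent1's outputs where the learner is to act, agent2's otherwise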
logits = torch.where(learn[:, None], logits1, logits2)
value = torch.where(learn[:, None], value1, value2)
return logits, value
def eval_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits = agent.get_logit(next_obs)
return logits
if args.compile:
train_step = torch.compile(train_step, mode=args.compile_mode)
predict_step = torch.compile(predict_step, mode=args.compile_mode)
train_step = torch.compile(train_step, mode=args.compile)
predict_step = torch.compile(predict_step, mode='default')
# eval_step = torch.compile(eval_step, mode=args.compile)
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
@@ -287,12 +337,12 @@ def run(local_rank, world_size):
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
to_plays = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
avg_ep_returns = []
avg_win_rates = []
avg_sp_win_rates = []
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
@@ -300,8 +350,16 @@ def run(local_rank, world_size):
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, dtype=torch.uint8)
next_to_play = to_tensor(info["to_play"])
next_done = torch.zeros(args.local_num_envs, device=device)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
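# assign the learning agent to player 0 in half of the envs and to player 1 in the rest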
ai_player_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player_)
ai_player = to_tensor(ai_player_, dtype=next_to_play.dtype)
next_value = 0
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
@@ -319,47 +377,44 @@ def run(local_rank, world_size):
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
to_plays[step] = next_to_play
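# learn marks the envs where it is the learning agent's (agent1's) turn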
learn = next_to_play == ai_player
learns[step] = learn
_start = time.time()
logits, value = predict_step(agent, next_obs)
logits, value = predict_step(agent1, agent2, next_obs, learn)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value.flatten()
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
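# keep the learner's most recent value estimate for bootstrapping; reset it when an episode has just ended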
next_value = torch.where(learn, value, next_value) * (1 - next_done.float())
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play = to_tensor(info["to_play"])
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
env_time += time.time() - _start
rewards[step] = to_tensor(reward)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_)
collect_time = time.time() - collect_start
print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool)
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
if info['is_selfplay'][idx]:
# win rate for the first player
pl = 1 if next_to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
avg_sp_win_rates.append(1 - winner)
else:
# win rate of agent
winner = 0 if episode_reward > 0 else 1
avg_win_rates.append(1 - winner)
avg_win_rates.append(win)
if random.random() < args.log_p:
n = 100
@@ -368,37 +423,62 @@ def run(local_rank, world_size):
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if len(avg_ep_returns) > n:
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
avg_ep_returns = []
if len(avg_win_rates) > n:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
avg_win_rates = []
if len(avg_sp_win_rates) > n:
writer.add_scalar("charts/avg_sp_win_rate", np.mean(avg_sp_win_rates), global_step)
avg_sp_win_rates = []
collect_time = time.time() - collect_start
print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
value = agent1.get_value(next_obs).reshape(-1)
advantages = torch.zeros_like(rewards).to(device)
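# bootstrap from agent1's value of next_obs where the learner acts next, else from its last tracked estimate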
nextvalues = torch.where(next_to_play == ai_player, value, next_value)
done_used = torch.zeros_like(next_done, dtype=torch.bool)
reward = 0
lastgaelam = 0
next_to_play_ = next_to_play
for t in reversed(range(args.num_steps)):
to_play = to_plays[t]
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
nextvalues = next_value
else:
nextnonterminal = 1.0 - dones[t + 1]
nextvalues = values[t + 1]
sp = 2.0 * (to_play == next_to_play_).float() - 1.0
delta = rewards[t] + args.gamma * nextvalues * sp * nextnonterminal - values[t]
lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
# TODO: experiment with it
# lastgaelam = lastgaelam * sp
advantages[t] = lastgaelam
next_to_play_ = to_play
# if learns[t]:
# if dones[t+1]:
# reward = rewards[t]
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# if not done_used:
# reward = reward
# nextvalues = 0
# lastgaelam = 0
# done_used = True
# else:
# reward = rewards[t]
# delta = reward + args.gamma * nextvalues - values[t]
# lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
# advantages[t] = lastgaelam_
# nextvalues = values[t]
# lastgaelam = lastgaelam_
# else:
# if dones[t+1]:
# reward = -rewards[t]
# done_used = False
# else:
# reward = reward
learn = learns[t]
if t != args.num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn.int() - 0.5)
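# sp is +1 where the learner acted at step t and -1 otherwise; on terminal steps the env
# reward is flipped to the learner's perspective (assuming rewards are given from the
# acting player's side). done_used tracks whether the learner has already consumed an
# episode's terminal signal, so reward, bootstrap value and the GAE chain are reset
# exactly once per episode.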
reward = torch.where(next_done, rewards[t] * sp, torch.where(learn & done_used, 0, reward))
real_done = next_done | ~done_used
nextvalues = torch.where(real_done, 0, nextvalues)
lastgaelam = torch.where(real_done, 0, lastgaelam)
done_used = torch.where(
next_done, learn, torch.where(learn & ~done_used, True, done_used))
delta = reward + args.gamma * nextvalues - values[t]
advantages[t] = lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
nextvalues = torch.where(learn, values[t], nextvalues)
lastgaelam = torch.where(learn, lastgaelam_, lastgaelam)
returns = advantages + values
_start = time.time()
@@ -412,6 +492,7 @@ def run(local_rank, world_size):
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
@@ -425,10 +506,10 @@ def run(local_rank, world_size):
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds])
reduce_gradidents(agent, args.world_size)
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
train_step(agent1, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
@@ -448,8 +529,8 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0:
if iteration % args.save_interval == 0 or iteration == args.num_iterations:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
if iteration % args.save_interval == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -471,10 +552,69 @@ def run(local_rank, world_size):
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)
if local_rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent2.load_state_dict(agent1.state_dict())
version += 1
if local_rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pth"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
_start = time.time()
episode_lengths = []
episode_rewards = []
eval_win_rates = []
e_obs = eval_envs.reset()[0]
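# greedy evaluation: act by argmax of the policy until each rank has collected local_eval_episodes episodes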
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_logits = eval_step(agent1, e_obs)
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= local_eval_episodes:
break
eval_return = np.mean(episode_rewards[:local_eval_episodes])
eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
# sync the statistics
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
if local_rank == 0:
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return}, eval_ep_len={eval_ep_len}, eval_win_rate={eval_win_rate}")
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if local_rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
writer.close()
......
@@ -320,66 +320,88 @@ class Encoder(nn.Module):
return f_actions, f_state, mask, valid
class PPOCritic(nn.Module):
# class PPOCritic(nn.Module):
def __init__(self, channels):
super(PPOCritic, self).__init__()
c = channels
# def __init__(self, channels):
# super(PPOCritic, self).__init__()
# c = channels
self.net = nn.Sequential(
nn.Linear(c * 2, c // 2),
nn.ReLU(),
nn.Linear(c // 2, 1),
)
# self.net = nn.Sequential(
# nn.Linear(c * 2, c // 2),
# nn.ReLU(),
# nn.Linear(c // 2, 1),
# )
# def forward(self, f_state):
# return self.net(f_state)
# class PPOActor(nn.Module):
# def __init__(self, channels):
# super(PPOActor, self).__init__()
# c = channels
# self.trans = nn.TransformerEncoderLayer(
# c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
# self.head = nn.Sequential(
# nn.Linear(c, c // 4),
# nn.ReLU(),
# nn.Linear(c // 4, 1),
# )
def forward(self, f_state):
return self.net(f_state)
# def forward(self, f_actions, mask, action):
# f_actions = self.trans(f_actions, src_key_padding_mask=mask)
# logits = self.head(f_actions)[..., 0]
# logits = logits.float()
# logits = logits.masked_fill(mask, float("-inf"))
# probs = Categorical(logits=logits)
# return probs.log_prob(action), probs.entropy()
class PPOActor(nn.Module):
# def predict(self, f_actions, mask):
# f_actions = self.trans(f_actions, src_key_padding_mask=mask)
# logits = self.head(f_actions)[..., 0]
# logits = logits.float()
# logits = logits.masked_fill(mask, float("-inf"))
# return logits
def __init__(self, channels):
super(PPOActor, self).__init__()
class Actor(nn.Module):
def __init__(self, channels, use_transformer=False):
super(Actor, self).__init__()
c = channels
self.trans = nn.TransformerEncoderLayer(
c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
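# optionally refine per-action features with a single transformer encoder layer before scoring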
self.use_transformer = use_transformer
if use_transformer:
self.transformer = nn.TransformerEncoderLayer(
c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
self.head = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
def forward(self, f_actions, mask, action):
f_actions = self.trans(f_actions, src_key_padding_mask=mask)
logits = self.head(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
return probs.log_prob(action), probs.entropy()
def predict(self, f_actions, mask):
f_actions = self.trans(f_actions, src_key_padding_mask=mask)
def forward(self, f_actions, mask):
if self.use_transformer:
f_actions = self.transformer(f_actions, src_key_padding_mask=mask)
logits = self.head(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
num_history_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True):
super(PPOAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
c = channels
self.actor = nn.Sequential(
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
self.actor = Actor(c, a_trans)
self.critic = nn.Sequential(
nn.Linear(c * 2, c // 2),
@@ -390,24 +412,15 @@ class PPOAgent(nn.Module):
def load_embeddings(self, embeddings, freeze=True):
self.encoder.load_embeddings(embeddings, freeze)
def get_value(self, x):
def get_logit(self, x):
f_actions, f_state, mask, valid = self.encoder(x)
return self.critic(f_state)
return self.actor(f_actions, mask)
def get_action_and_value(self, x, action):
def get_value(self, x):
f_actions, f_state, mask, valid = self.encoder(x)
logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
return action, probs.log_prob(action), probs.entropy(), self.critic(f_state), valid
return self.critic(f_state)
def forward(self, x):
f_actions, f_state, mask, valid = self.encoder(x)
logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits, self.critic(f_state)
logits = self.actor(f_actions, mask)
return logits, self.critic(f_state), valid
@@ -2935,7 +2935,6 @@ private:
return;
}
auto player = read_u8();
to_play_ = player;
auto size = read_u8();
std::vector<Card> cards;
for (int i = 0; i < size; ++i) {
@@ -3315,7 +3314,6 @@ private:
throw std::runtime_error("Retry");
} else if (msg_ == MSG_SELECT_BATTLECMD) {
auto player = read_u8();
to_play_ = player;
auto activatable = read_cardlist_spec(true);
auto attackable = read_cardlist_spec(true, true);
bool to_m2 = read_u8();
@@ -3366,6 +3364,7 @@ private:
}
int n_activatables = activatable.size();
int n_attackables = attackable.size();
to_play_ = player;
callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) {
if (idx < n_activatables) {
OCG_SetResponsei(pduel_, idx << 16);
@@ -3382,7 +3381,6 @@ private:
};
} else if (msg_ == MSG_SELECT_UNSELECT_CARD) {
auto player = read_u8();
to_play_ = player;
bool finishable = read_u8();
bool cancelable = read_u8();
auto min = read_u8();
@@ -3435,6 +3433,7 @@ private:
// cancelable and finishable not needed
to_play_ = player;
callback_ = [this](int idx) {
if (options_[idx] == "f") {
OCG_SetResponsei(pduel_, -1);
@@ -3447,7 +3446,6 @@ private:
} else if (msg_ == MSG_SELECT_CARD) {
auto player = read_u8();
to_play_ = player;
bool cancelable = read_u8();
auto min = read_u8();
auto max = read_u8();
@@ -3535,6 +3533,7 @@ private:
}
}
to_play_ = player;
callback_ = [this, combs](int idx) {
const auto &comb = combs[idx];
resp_buf_[0] = comb.size();
@@ -3545,7 +3544,6 @@ private:
};
} else if (msg_ == MSG_SELECT_TRIBUTE) {
auto player = read_u8();
to_play_ = player;
bool cancelable = read_u8();
auto min = read_u8();
auto max = read_u8();
@@ -3621,6 +3619,7 @@ private:
options_.push_back(option);
}
to_play_ = player;
callback_ = [this, combs](int idx) {
const auto &comb = combs[idx];
resp_buf_[0] = comb.size();
@@ -3632,7 +3631,6 @@ private:
} else if (msg_ == MSG_SELECT_SUM) {
auto mode = read_u8();
auto player = read_u8();
to_play_ = player;
auto val = read_u32();
auto min = read_u8();
auto max = read_u8();
@@ -3761,6 +3759,7 @@ private:
options_.push_back(option);
}
to_play_ = player;
callback_ = [this, combs, must_select_size](int idx) {
const auto &comb = combs[idx];
resp_buf_[0] = must_select_size + comb.size();
@@ -3775,7 +3774,6 @@ private:
} else if (msg_ == MSG_SELECT_CHAIN) {
auto player = read_u8();
to_play_ = player;
auto size = read_u8();
auto spe_count = read_u8();
bool forced = read_u8();
@@ -3872,6 +3870,7 @@ private:
if (!forced) {
options_.push_back("c");
}
to_play_ = player;
callback_ = [this, forced](int idx) {
const auto &option = options_[idx];
if ((option == "c") && (!forced)) {
@@ -3882,7 +3881,6 @@ private:
};
} else if (msg_ == MSG_SELECT_YESNO) {
auto player = read_u8();
to_play_ = player;
if (verbose_) {
auto desc = read_u32();
@@ -3907,6 +3905,7 @@ private:
dp_ += 4;
}
options_ = {"y", "n"};
to_play_ = player;
callback_ = [this](int idx) {
if (idx == 0) {
OCG_SetResponsei(pduel_, 1);
@@ -3918,7 +3917,6 @@ private:
};
} else if (msg_ == MSG_SELECT_EFFECTYN) {
auto player = read_u8();
to_play_ = player;
std::string spec;
if (verbose_) {
@@ -3981,6 +3979,7 @@ private:
spec = ls_to_spec(loc, seq, pos, c != player);
}
options_ = {"y " + spec, "n " + spec};
to_play_ = player;
callback_ = [this](int idx) {
if (idx == 0) {
OCG_SetResponsei(pduel_, 1);
@@ -3992,7 +3991,6 @@ private:
};
} else if (msg_ == MSG_SELECT_OPTION) {
auto player = read_u8();
to_play_ = player;
auto size = read_u8();
if (verbose_) {
auto pl = players_[player];
@@ -4016,6 +4014,7 @@ private:
options_.push_back(std::to_string(i + 1));
}
}
to_play_ = player;
callback_ = [this](int idx) {
if (verbose_) {
players_[to_play_]->notify("You selected option " + options_[idx] +
@@ -4029,7 +4028,6 @@ private:
};
} else if (msg_ == MSG_SELECT_IDLECMD) {
int32_t player = read_u8();
to_play_ = player;
auto summonable_ = read_cardlist_spec();
auto spsummon_ = read_cardlist_spec();
auto repos_ = read_cardlist_spec();
@@ -4134,6 +4132,7 @@ private:
}
}
to_play_ = player;
callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset,
activate_offset](int idx) {
const auto &option = options_[idx];
@@ -4169,7 +4168,6 @@ private:
};
} else if (msg_ == MSG_SELECT_PLACE) {
auto player = read_u8();
to_play_ = player;
auto count = read_u8();
if (count == 0) {
count = 1;
@@ -4189,6 +4187,7 @@ private:
" places for card, from " + specs_str + ".");
}
}
to_play_ = player;
callback_ = [this, player](int idx) {
int y = player + 1;
std::string spec = options_[idx];
@@ -4205,7 +4204,6 @@ private:
};
} else if (msg_ == MSG_SELECT_DISFIELD) {
auto player = read_u8();
to_play_ = player;
auto count = read_u8();
if (count == 0) {
count = 1;
@@ -4225,6 +4223,7 @@ private:
std::to_string(count) + " not implemented");
}
}
to_play_ = player;
callback_ = [this, player](int idx) {
int y = player + 1;
std::string spec = options_[idx];
@@ -4241,7 +4240,6 @@ private:
};
} else if (msg_ == MSG_ANNOUNCE_NUMBER) {
auto player = read_u8();
to_play_ = player;
auto count = read_u8();
std::vector<int> numbers;
for (int i = 0; i < count; ++i) {
@@ -4265,12 +4263,12 @@ private:
str += "]";
pl->notify(str);
}
to_play_ = player;
callback_ = [this](int idx) {
OCG_SetResponsei(pduel_, idx);
};
} else if (msg_ == MSG_ANNOUNCE_ATTRIB) {
auto player = read_u8();
to_play_ = player;
auto count = read_u8();
auto flag = read_u32();
@@ -4310,6 +4308,7 @@ private:
options_.push_back(option);
}
to_play_ = player;
callback_ = [this](int idx) {
const auto &option = options_[idx];
uint32_t resp = 0;
@@ -4323,7 +4322,6 @@ private:
} else if (msg_ == MSG_SELECT_POSITION) {
auto player = read_u8();
to_play_ = player;
auto code = read_u32();
auto valid_pos = read_u8();
@@ -4348,6 +4346,7 @@ private:
i++;
}
to_play_ = player;
callback_ = [this](int idx) {
uint8_t pos = options_[idx][0] - '1';
OCG_SetResponsei(pduel_, 1 << pos);
......