Commit 385bd1cb authored by biluo.shen

Add PPO selfplay

parent af60d012
@@ -147,21 +147,21 @@ if __name__ == "__main__":
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent = agent.eval() # agent = agent.eval()
if args.checkpoint: if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device) state_dict = torch.load(args.checkpoint, map_location=device)
else: else:
state_dict = None state_dict = None
if args.compile: if args.compile:
agent = torch.compile(agent, mode='reduce-overhead')
if state_dict: if state_dict:
agent.load_state_dict(state_dict) print(agent.load_state_dict(state_dict))
agent = torch.compile(agent, mode='reduce-overhead')
else: else:
prefix = "_orig_mod." prefix = "_orig_mod."
if state_dict: if state_dict:
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()} state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
agent.load_state_dict(state_dict) print(agent.load_state_dict(state_dict))
if args.optimize: if args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device) obs = create_obs(envs.observation_space, (num_envs,), device=device)
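When the checkpoint was saved from a torch.compile'd model, its state_dict keys carry a "_orig_mod." prefix, which is why the non-compiled branch above strips that prefix before loading. A minimal, self-contained sketch of the same idea (a toy nn.Linear instead of the project's Agent; assumes PyTorch 2.x):

    import torch
    import torch.nn as nn

    PREFIX = "_orig_mod."

    def strip_compile_prefix(state_dict):
        # torch.compile wraps the module as `_orig_mod`, so saved keys carry that prefix;
        # drop it so the checkpoint loads into a plain, uncompiled module.
        return {k[len(PREFIX):] if k.startswith(PREFIX) else k: v
                for k, v in state_dict.items()}

    model = nn.Linear(4, 2)
    compiled = torch.compile(model)
    ckpt = compiled.state_dict()          # keys look like "_orig_mod.weight", "_orig_mod.bias"
    plain = nn.Linear(4, 2)
    print(plain.load_state_dict(strip_compile_prefix(ckpt)))   # <All keys matched successfully>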
@@ -170,6 +170,7 @@ if __name__ == "__main__":
agent = torch.jit.optimize_for_inference(traced_model) agent = torch.jit.optimize_for_inference(traced_model)
obs, infos = envs.reset() obs, infos = envs.reset()
next_to_play = infos['to_play']
episode_rewards = [] episode_rewards = []
episode_lengths = [] episode_lengths = []
@@ -191,7 +192,7 @@ if __name__ == "__main__":
_start = time.time() _start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs) obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad(): with torch.no_grad():
logits, values = agent(obs) logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1) probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy() probs = probs.cpu().numpy()
if args.play: if args.play:
@@ -212,9 +213,11 @@ if __name__ == "__main__":
# print(k, v.tolist()) # print(k, v.tolist())
# print(infos) # print(infos)
# print(actions[0]) # print(actions[0])
to_play = next_to_play
_start = time.time() _start = time.time()
obs, rewards, dones, infos = envs.step(actions) obs, rewards, dones, infos = envs.step(actions)
next_to_play = infos['to_play']
env_time += time.time() - _start env_time += time.time() - _start
step += 1 step += 1
@@ -225,7 +228,7 @@ if __name__ == "__main__":
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] episode_reward = infos['r'][idx]
if args.selfplay: if args.selfplay:
pl = 1 if infos['to_play'][idx] == 0 else -1 pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1 winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner win = 1 - winner
else: else:
...
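In the self-play branch of the hunk above, the winner is recovered from the sign of the terminal reward together with the seat that acted last (the to_play captured before the final step); this assumes the environment reports the reward from the acting player's perspective. A scalar sketch of that bookkeeping:

    def winner_from_reward(episode_reward: float, acting_player: int) -> int:
        # acting_player: seat (0 or 1) that took the final action of the episode.
        # Assumption: episode_reward is signed from the acting player's point of view.
        pl = 1 if acting_player == 0 else -1
        return 0 if episode_reward * pl > 0 else 1

    assert winner_from_reward(+1.0, acting_player=0) == 0   # player 0 moved last and won
    assert winner_from_reward(+1.0, acting_player=1) == 1   # player 1 moved last and won
    assert winner_from_reward(-1.0, acting_player=0) == 1   # player 0 moved last and lost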
import os import os
import random import random
import time import time
from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, Optional from typing import Literal, Optional
@@ -52,7 +53,7 @@ class Args:
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 16 n_history_actions: int = 16
"""the number of history actions to use""" """the number of history actions to use"""
play_mode: str = "self+bot" play_mode: str = "bot"
"""the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'""" """the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""
num_layers: int = 2 num_layers: int = 2
@@ -74,6 +75,12 @@ class Args:
"""the discount factor gamma""" """the discount factor gamma"""
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
update_win_rate: float = 0.6
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256 minibatch_size: int = 256
"""the mini-batch size""" """the mini-batch size"""
update_epochs: int = 2 update_epochs: int = 2
@@ -95,10 +102,8 @@ class Args:
backend: Literal["gloo", "nccl", "mpi"] = "nccl" backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training""" """the backend for distributed training"""
compile: bool = True compile: Optional[str] = None
"""whether to use torch.compile to compile the model and functions""" """Compile mode of torch.compile, None for no compilation"""
compile_mode: Optional[str] = None
"""the mode to use for torch.compile"""
torch_threads: Optional[int] = None torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size""" """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None env_threads: Optional[int] = None
@@ -118,6 +123,8 @@ class Args:
"""the probability of logging""" """the probability of logging"""
port: int = 12356 port: int = 12356
"""the port to use for distributed training""" """the port to use for distributed training"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
# to be filled in runtime # to be filled in runtime
local_batch_size: int = 0 local_batch_size: int = 0
@@ -197,7 +204,7 @@ def run(local_rank, world_size):
deck2=args.deck2, deck2=args.deck2,
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
play_mode=args.play_mode, play_mode='self',
) )
envs.num_envs = args.local_num_envs envs.num_envs = args.local_num_envs
obs_space = envs.observation_space obs_space = envs.observation_space
@@ -205,7 +212,25 @@ def run(local_rank, world_size):
if local_rank == 0: if local_rank == 0:
print(f"obs_space={obs_space}, action_shape={action_shape}") print(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file: if args.embedding_file:
embeddings = np.load(args.embedding_file) embeddings = np.load(args.embedding_file)
@@ -213,11 +238,14 @@ def run(local_rank, world_size):
else: else:
embedding_shape = None embedding_shape = None
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device) agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
if args.embedding_file: if args.embedding_file:
agent.load_embeddings(embeddings) agent1.load_embeddings(embeddings)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2.load_state_dict(agent1.state_dict())
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) optim_params = list(agent1.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8) scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
@@ -225,9 +253,21 @@ def run(local_rank, world_size):
x = x.masked_fill(~valid, 0) x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum() return x.sum() / valid.float().sum()
def train_step(agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values): def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train): with autocast(enabled=args.fp16_train):
_, newlogprob, entropy, newvalue, valid = agent.get_action_and_value(mb_obs, mb_actions.long()) logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs logratio = newlogprob - mb_logprobs
ratio = logratio.exp() ratio = logratio.exp()
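masked_mean and the new masked_normalize above share one idea: invalid positions are zeroed before the sums and the denominator comes from the mask, so padded or off-turn steps cannot skew the statistics (note that, as written, the zeroed entries still enter the squared-error sum of the variance). A quick toy check, independent of the training buffers:

    import torch

    def masked_mean(x, valid):
        x = x.masked_fill(~valid, 0)
        return x.sum() / valid.float().sum()

    def masked_normalize(x, valid, eps=1e-8):
        x = x.masked_fill(~valid, 0)
        n = valid.float().sum()
        mean = x.sum() / n
        var = ((x - mean) ** 2).sum() / n
        std = (var + eps).sqrt()
        return (x - mean) / std

    adv = torch.tensor([1.0, 2.0, 3.0, 100.0])
    valid = torch.tensor([True, True, True, False])
    print(masked_mean(adv, valid))                      # tensor(2.) -- the masked 100.0 is ignored
    print(masked_normalize(adv, valid)[valid].mean())   # ~0 over the valid entries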
@@ -238,7 +278,7 @@ def run(local_rank, world_size):
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv: if args.norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss # Policy loss
pg_loss1 = -mb_advantages * ratio pg_loss1 = -mb_advantages * ratio
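With self-play data, valid = torch.logical_and(valid, mb_learns) restricts the loss to steps where the action was legal and the learning agent (not the frozen opponent) was acting; the lines not shown in this hunk presumably reduce the per-step losses with the masked_mean helper. A condensed sketch of the clipped policy loss under such a mask (toy tensors, not the real minibatch):

    import torch

    def masked_mean(x, valid):
        x = x.masked_fill(~valid, 0)
        return x.sum() / valid.float().sum()

    clip_coef = 0.2
    newlogprob = torch.randn(8)
    oldlogprob = torch.randn(8)
    advantages = torch.randn(8)
    action_valid = torch.tensor([True, True, False, True, True, True, False, True])
    learns = torch.tensor([True, False, True, True, False, True, True, True])
    valid = action_valid & learns          # only the learning agent's legal steps count

    ratio = (newlogprob - oldlogprob).exp()
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), valid)
    print(pg_loss)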
@@ -269,15 +309,25 @@ def run(local_rank, world_size):
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent, next_obs): def predict_step(agent1: Agent, agent2: Agent, next_obs, learn):
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): with autocast(enabled=args.fp16_eval):
logits, values = agent(next_obs) logits1, value1, valid = agent1(next_obs)
return logits, values logits2, value2, valid = agent2(next_obs)
logits = torch.where(learn[:, None], logits1, logits2)
value = torch.where(learn[:, None], value1, value2)
return logits, value
def eval_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits = agent.get_logit(next_obs)
return logits
if args.compile: if args.compile:
train_step = torch.compile(train_step, mode=args.compile_mode) train_step = torch.compile(train_step, mode=args.compile)
predict_step = torch.compile(predict_step, mode=args.compile_mode) predict_step = torch.compile(predict_step, mode='default')
# eval_step = torch.compile(eval_step, mode=args.compile)
def to_tensor(x, dtype=torch.float32): def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x) return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
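predict_step above evaluates both networks on the full batch and then picks per environment with torch.where, driven by the learn mask (True where the env's current player is the training agent, False where the frozen copy plays). A minimal sketch of that per-env selection, with stand-in linear models instead of the project's Agent:

    import torch
    import torch.nn as nn

    num_envs, num_actions, feat = 4, 5, 8
    agent1 = nn.Linear(feat, num_actions)   # stand-in for the learning agent
    agent2 = nn.Linear(feat, num_actions)   # stand-in for the frozen opponent

    obs = torch.randn(num_envs, feat)
    learn = torch.tensor([True, False, True, False])   # which envs agent1 controls this step

    with torch.no_grad():
        logits1, logits2 = agent1(obs), agent2(obs)
    logits = torch.where(learn[:, None], logits1, logits2)   # broadcast mask over actions
    assert torch.equal(logits[0], logits1[0]) and torch.equal(logits[1], logits2[1])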
@@ -287,12 +337,12 @@ def run(local_rank, world_size):
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device) actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device) logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device) rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device) dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device) values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
to_plays = torch.zeros((args.num_steps, args.local_num_envs)).to(device) learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = [] avg_ep_returns = deque(maxlen=1000)
avg_win_rates = [] avg_win_rates = deque(maxlen=1000)
avg_sp_win_rates = [] version = 0
# TRY NOT TO MODIFY: start the game # TRY NOT TO MODIFY: start the game
global_step = 0 global_step = 0
@@ -300,8 +350,16 @@ def run(local_rank, world_size):
start_time = time.time() start_time = time.time()
next_obs, info = envs.reset() next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, dtype=torch.uint8) next_obs = to_tensor(next_obs, dtype=torch.uint8)
next_to_play = to_tensor(info["to_play"]) next_to_play_ = info["to_play"]
next_done = torch.zeros(args.local_num_envs, device=device) next_to_play = to_tensor(next_to_play_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player_)
ai_player = to_tensor(ai_player_, dtype=next_to_play.dtype)
next_value = 0
for iteration in range(1, args.num_iterations + 1): for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so. # Annealing the rate if instructed to do so.
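ai_player fixes, once per rollout worker, which seat the learning agent occupies in each environment (half the envs as player 0, half as player 1, shuffled), so each step the learn mask is just a comparison against the env's to_play. A small sketch of that bookkeeping with made-up values:

    import numpy as np

    num_envs = 8
    ai_player = np.concatenate([
        np.zeros(num_envs // 2, dtype=np.int64),   # learning agent sits as player 0 here
        np.ones(num_envs // 2, dtype=np.int64),    # ...and as player 1 here
    ])
    np.random.shuffle(ai_player)

    to_play = np.random.randint(0, 2, size=num_envs)   # reported by the envs each step
    learn = to_play == ai_player                        # True where the learning agent must act
    print(ai_player, to_play, learn, sep="\n")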
@@ -319,47 +377,44 @@ def run(local_rank, world_size):
for key in obs: for key in obs:
obs[key][step] = next_obs[key] obs[key][step] = next_obs[key]
dones[step] = next_done dones[step] = next_done
to_plays[step] = next_to_play learn = next_to_play == ai_player
learns[step] = learn
_start = time.time() _start = time.time()
logits, value = predict_step(agent, next_obs) logits, value = predict_step(agent1, agent2, next_obs, learn)
value = value.flatten()
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
action = probs.sample() action = probs.sample()
logprob = probs.log_prob(action) logprob = probs.log_prob(action)
values[step] = value.flatten() values[step] = value
actions[step] = action actions[step] = action
logprobs[step] = logprob logprobs[step] = logprob
action = action.cpu().numpy() action = action.cpu().numpy()
model_time += time.time() - _start model_time += time.time() - _start
next_value = torch.where(learn, value, next_value) * (1 - next_done.float())
_start = time.time() _start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action) next_obs, reward, next_done_, info = envs.step(action)
next_to_play = to_tensor(info["to_play"]) next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
env_time += time.time() - _start env_time += time.time() - _start
rewards[step] = to_tensor(reward) rewards[step] = to_tensor(reward)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_) next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool)
collect_time = time.time() - collect_start
print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
if not writer: if not writer:
continue continue
for idx, d in enumerate(next_done_): for idx, d in enumerate(next_done_):
if d: if d:
pl = 1 if to_play[idx] == ai_player_[idx] else -1
episode_length = info['l'][idx] episode_length = info['l'][idx]
episode_reward = info['r'][idx] episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward) avg_ep_returns.append(episode_reward)
if info['is_selfplay'][idx]: avg_win_rates.append(win)
# win rate for the first player
pl = 1 if next_to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
avg_sp_win_rates.append(1 - winner)
else:
# win rate of agent
winner = 0 if episode_reward > 0 else 1
avg_win_rates.append(1 - winner)
if random.random() < args.log_p: if random.random() < args.log_p:
n = 100 n = 100
@@ -368,37 +423,62 @@ def run(local_rank, world_size):
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}") print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if len(avg_ep_returns) > n: if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step) writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
avg_ep_returns = []
if len(avg_win_rates) > n:
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step) writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
avg_win_rates = []
if len(avg_sp_win_rates) > n: collect_time = time.time() - collect_start
writer.add_scalar("charts/avg_sp_win_rate", np.mean(avg_sp_win_rates), global_step) print(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}", flush=True)
avg_sp_win_rates = []
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1) value = agent1.get_value(next_obs).reshape(-1)
advantages = torch.zeros_like(rewards).to(device) advantages = torch.zeros_like(rewards).to(device)
nextvalues = torch.where(next_to_play == ai_player, value, next_value)
done_used = torch.zeros_like(next_done, dtype=torch.bool)
reward = 0
lastgaelam = 0 lastgaelam = 0
next_to_play_ = next_to_play
for t in reversed(range(args.num_steps)): for t in reversed(range(args.num_steps)):
to_play = to_plays[t] # if learns[t]:
if t == args.num_steps - 1: # if dones[t+1]:
nextnonterminal = 1.0 - next_done # reward = rewards[t]
nextvalues = next_value # nextvalues = 0
else: # lastgaelam = 0
nextnonterminal = 1.0 - dones[t + 1] # done_used = True
nextvalues = values[t + 1] # else:
sp = 2.0 * (to_play == next_to_play_).float() - 1.0 # if not done_used:
delta = rewards[t] + args.gamma * nextvalues * sp * nextnonterminal - values[t] # reward = reward
lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam # nextvalues = 0
# TODO: experiment with it # lastgaelam = 0
# lastgaelam = lastgaelam * sp # done_used = True
advantages[t] = lastgaelam # else:
next_to_play_ = to_play # reward = rewards[t]
# delta = reward + args.gamma * nextvalues - values[t]
# lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
# advantages[t] = lastgaelam_
# nextvalues = values[t]
# lastgaelam = lastgaelam_
# else:
# if dones[t+1]:
# reward = -rewards[t]
# done_used = False
# else:
# reward = reward
learn = learns[t]
if t != args.num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn.int() - 0.5)
reward = torch.where(next_done, rewards[t] * sp, torch.where(learn & done_used, 0, reward))
real_done = next_done | ~done_used
nextvalues = torch.where(real_done, 0, nextvalues)
lastgaelam = torch.where(real_done, 0, lastgaelam)
done_used = torch.where(
next_done, learn, torch.where(learn & ~done_used, True, done_used))
delta = reward + args.gamma * nextvalues - values[t]
advantages[t] = lastgaelam_ = delta + args.gamma * args.gae_lambda * lastgaelam
nextvalues = torch.where(learn, values[t], nextvalues)
lastgaelam = torch.where(learn, lastgaelam_, lastgaelam)
returns = advantages + values returns = advantages + values
_start = time.time() _start = time.time()
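The vectorized torch.where loop above implements the per-environment logic kept in the comments: advantages are only meaningful on the learning agent's own steps, an opponent's terminal move hands its (negated) outcome backwards to the agent's previous step, and value/GAE carries are reset across episode boundaries. A pure-Python, single-environment rendering of that commented pseudocode (illustrative only; the initial carry values are assumptions of this sketch, not taken from the code):

    def selfplay_gae_single_env(rewards, values, learns, dones,
                                gamma=0.997, gae_lambda=0.95):
        # rewards/values/learns: length T, one environment; dones: length T + 1,
        # where dones[t + 1] says the action at step t ended the episode.
        T = len(rewards)
        advantages = [0.0] * T
        reward, nextvalues, lastgaelam, done_used = 0.0, 0.0, 0.0, True
        for t in reversed(range(T)):
            if learns[t]:                              # the learning agent acted at t
                if dones[t + 1]:                       # ...and its action ended the episode
                    reward, nextvalues, lastgaelam, done_used = rewards[t], 0.0, 0.0, True
                elif not done_used:                    # opponent ended the episode after t
                    nextvalues, lastgaelam, done_used = 0.0, 0.0, True
                else:                                  # ordinary intermediate step
                    reward = rewards[t]
                delta = reward + gamma * nextvalues - values[t]
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * lastgaelam
                nextvalues = values[t]                 # carry for the agent's next-earlier step
            else:                                      # frozen opponent's turn
                if dones[t + 1]:
                    reward = -rewards[t]               # opponent's win is the agent's loss
                    done_used = False                  # outcome not yet consumed by the agent
        return advantages

    # 3 steps of one env: the agent acts at t=0 and t=2, and its move at t=2 wins the game.
    print(selfplay_gae_single_env(rewards=[0.0, 0.0, 1.0], values=[0.1, -0.2, 0.3],
                                  learns=[True, False, True],
                                  dones=[False, False, False, True]))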
@@ -412,6 +492,7 @@ def run(local_rank, world_size):
b_advantages = advantages.reshape(-1) b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1) b_returns = returns.reshape(-1)
b_values = values.reshape(-1) b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network # Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size) b_inds = np.arange(args.local_batch_size)
@@ -425,10 +506,10 @@ def run(local_rank, world_size):
k: v[mb_inds] for k, v in b_obs.items() k: v[mb_inds] for k, v in b_obs.items()
} }
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \ old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], train_step(agent1, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds]) b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(agent, args.world_size) reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
clipfracs.append(clipfrac.item()) clipfracs.append(clipfrac.item())
@@ -448,8 +529,8 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes # TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0: if local_rank == 0:
if iteration % args.save_interval == 0 or iteration == args.num_iterations: if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth")) torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -471,10 +552,69 @@ def run(local_rank, world_size):
print("SPS:", SPS) print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step) writer.add_scalar("charts/SPS", SPS, global_step)
if local_rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent2.load_state_dict(agent1.state_dict())
version += 1
if local_rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pth"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
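The opponent is only refreshed when rank 0 has seen 1000 recent games with both the win rate and the average return above the configured thresholds; the decision is turned into an integer flag and all-reduced so every rank swaps in the new weights in the same iteration. A minimal sketch of that broadcast-by-reduction pattern (process-group setup omitted; assumes torch.distributed is already initialized):

    import torch
    import torch.distributed as dist

    def agree_on_flag(decision: bool, local_rank: int, world_size: int, device) -> bool:
        # Rank 0 proposes; the SUM over ranks is > 0 exactly when rank 0 said yes,
        # so every rank ends up with the same boolean.
        flag = torch.tensor(int(decision) if local_rank == 0 else 0,
                            dtype=torch.int64, device=device)
        if world_size > 1:
            dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() > 0

A plain dist.broadcast from rank 0 would serve the same purpose; the all_reduce form simply reuses a collective already in play.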
_start = time.time()
episode_lengths = []
episode_rewards = []
eval_win_rates = []
e_obs = eval_envs.reset()[0]
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_logits = eval_step(agent1, e_obs)
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= local_eval_episodes:
break
eval_return = np.mean(episode_rewards[:local_eval_episodes])
eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
# sync the statistics
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
if local_rank == 0:
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return}, eval_ep_len={eval_ep_len}, eval_win_rate={eval_win_rate}")
if args.world_size > 1: if args.world_size > 1:
dist.destroy_process_group() dist.destroy_process_group()
envs.close() envs.close()
if local_rank == 0: if local_rank == 0:
torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
writer.close() writer.close()
...
@@ -320,28 +320,60 @@ class Encoder(nn.Module):
return f_actions, f_state, mask, valid return f_actions, f_state, mask, valid
class PPOCritic(nn.Module): # class PPOCritic(nn.Module):
def __init__(self, channels): # def __init__(self, channels):
super(PPOCritic, self).__init__() # super(PPOCritic, self).__init__()
c = channels # c = channels
self.net = nn.Sequential( # self.net = nn.Sequential(
nn.Linear(c * 2, c // 2), # nn.Linear(c * 2, c // 2),
nn.ReLU(), # nn.ReLU(),
nn.Linear(c // 2, 1), # nn.Linear(c // 2, 1),
) # )
# def forward(self, f_state):
# return self.net(f_state)
# class PPOActor(nn.Module):
# def __init__(self, channels):
# super(PPOActor, self).__init__()
# c = channels
# self.trans = nn.TransformerEncoderLayer(
# c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
# self.head = nn.Sequential(
# nn.Linear(c, c // 4),
# nn.ReLU(),
# nn.Linear(c // 4, 1),
# )
def forward(self, f_state): # def forward(self, f_actions, mask, action):
return self.net(f_state) # f_actions = self.trans(f_actions, src_key_padding_mask=mask)
# logits = self.head(f_actions)[..., 0]
# logits = logits.float()
# logits = logits.masked_fill(mask, float("-inf"))
# probs = Categorical(logits=logits)
# return probs.log_prob(action), probs.entropy()
class PPOActor(nn.Module): # def predict(self, f_actions, mask):
# f_actions = self.trans(f_actions, src_key_padding_mask=mask)
# logits = self.head(f_actions)[..., 0]
# logits = logits.float()
# logits = logits.masked_fill(mask, float("-inf"))
# return logits
def __init__(self, channels):
super(PPOActor, self).__init__() class Actor(nn.Module):
def __init__(self, channels, use_transformer=False):
super(Actor, self).__init__()
c = channels c = channels
self.trans = nn.TransformerEncoderLayer( self.use_transformer = use_transformer
if use_transformer:
self.transformer = nn.TransformerEncoderLayer(
c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False) c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
self.head = nn.Sequential( self.head = nn.Sequential(
nn.Linear(c, c // 4), nn.Linear(c, c // 4),
@@ -349,37 +381,27 @@ class PPOActor(nn.Module):
nn.Linear(c // 4, 1), nn.Linear(c // 4, 1),
) )
def forward(self, f_actions, mask, action): def forward(self, f_actions, mask):
f_actions = self.trans(f_actions, src_key_padding_mask=mask) if self.use_transformer:
logits = self.head(f_actions)[..., 0] f_actions = self.transformer(f_actions, src_key_padding_mask=mask)
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
return probs.log_prob(action), probs.entropy()
def predict(self, f_actions, mask):
f_actions = self.trans(f_actions, src_key_padding_mask=mask)
logits = self.head(f_actions)[..., 0] logits = self.head(f_actions)[..., 0]
logits = logits.float() logits = logits.float()
logits = logits.masked_fill(mask, float("-inf")) logits = logits.masked_fill(mask, float("-inf"))
return logits return logits
class PPOAgent(nn.Module): class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True): num_history_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True):
super(PPOAgent, self).__init__() super(PPOAgent, self).__init__()
self.encoder = Encoder( self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine) channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
c = channels c = channels
self.actor = nn.Sequential( self.actor = Actor(c, a_trans)
nn.Linear(c, c // 4),
nn.ReLU(),
nn.Linear(c // 4, 1),
)
self.critic = nn.Sequential( self.critic = nn.Sequential(
nn.Linear(c * 2, c // 2), nn.Linear(c * 2, c // 2),
@@ -390,24 +412,15 @@ class PPOAgent(nn.Module):
def load_embeddings(self, embeddings, freeze=True): def load_embeddings(self, embeddings, freeze=True):
self.encoder.load_embeddings(embeddings, freeze) self.encoder.load_embeddings(embeddings, freeze)
def get_value(self, x): def get_logit(self, x):
f_actions, f_state, mask, valid = self.encoder(x) f_actions, f_state, mask, valid = self.encoder(x)
return self.critic(f_state) return self.actor(f_actions, mask)
def get_action_and_value(self, x, action): def get_value(self, x):
f_actions, f_state, mask, valid = self.encoder(x) f_actions, f_state, mask, valid = self.encoder(x)
return self.critic(f_state)
logits = self.actor(f_actions)[..., 0]
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
probs = Categorical(logits=logits)
return action, probs.log_prob(action), probs.entropy(), self.critic(f_state), valid
def forward(self, x): def forward(self, x):
f_actions, f_state, mask, valid = self.encoder(x) f_actions, f_state, mask, valid = self.encoder(x)
logits = self.actor(f_actions, mask)
logits = self.actor(f_actions)[..., 0] return logits, self.critic(f_state), valid
logits = logits.float()
logits = logits.masked_fill(mask, float("-inf"))
return logits, self.critic(f_state)
@@ -2935,7 +2935,6 @@ private:
return; return;
} }
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto size = read_u8(); auto size = read_u8();
std::vector<Card> cards; std::vector<Card> cards;
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
@@ -3315,7 +3314,6 @@ private:
throw std::runtime_error("Retry"); throw std::runtime_error("Retry");
} else if (msg_ == MSG_SELECT_BATTLECMD) { } else if (msg_ == MSG_SELECT_BATTLECMD) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto activatable = read_cardlist_spec(true); auto activatable = read_cardlist_spec(true);
auto attackable = read_cardlist_spec(true, true); auto attackable = read_cardlist_spec(true, true);
bool to_m2 = read_u8(); bool to_m2 = read_u8();
@@ -3366,6 +3364,7 @@ private:
} }
int n_activatables = activatable.size(); int n_activatables = activatable.size();
int n_attackables = attackable.size(); int n_attackables = attackable.size();
to_play_ = player;
callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) { callback_ = [this, n_activatables, n_attackables, to_ep, to_m2](int idx) {
if (idx < n_activatables) { if (idx < n_activatables) {
OCG_SetResponsei(pduel_, idx << 16); OCG_SetResponsei(pduel_, idx << 16);
@@ -3382,7 +3381,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_UNSELECT_CARD) { } else if (msg_ == MSG_SELECT_UNSELECT_CARD) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
bool finishable = read_u8(); bool finishable = read_u8();
bool cancelable = read_u8(); bool cancelable = read_u8();
auto min = read_u8(); auto min = read_u8();
@@ -3435,6 +3433,7 @@ private:
// cancelable and finishable not needed // cancelable and finishable not needed
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (options_[idx] == "f") { if (options_[idx] == "f") {
OCG_SetResponsei(pduel_, -1); OCG_SetResponsei(pduel_, -1);
@@ -3447,7 +3446,6 @@ private:
} else if (msg_ == MSG_SELECT_CARD) { } else if (msg_ == MSG_SELECT_CARD) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
bool cancelable = read_u8(); bool cancelable = read_u8();
auto min = read_u8(); auto min = read_u8();
auto max = read_u8(); auto max = read_u8();
@@ -3535,6 +3533,7 @@ private:
} }
} }
to_play_ = player;
callback_ = [this, combs](int idx) { callback_ = [this, combs](int idx) {
const auto &comb = combs[idx]; const auto &comb = combs[idx];
resp_buf_[0] = comb.size(); resp_buf_[0] = comb.size();
@@ -3545,7 +3544,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_TRIBUTE) { } else if (msg_ == MSG_SELECT_TRIBUTE) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
bool cancelable = read_u8(); bool cancelable = read_u8();
auto min = read_u8(); auto min = read_u8();
auto max = read_u8(); auto max = read_u8();
@@ -3621,6 +3619,7 @@ private:
options_.push_back(option); options_.push_back(option);
} }
to_play_ = player;
callback_ = [this, combs](int idx) { callback_ = [this, combs](int idx) {
const auto &comb = combs[idx]; const auto &comb = combs[idx];
resp_buf_[0] = comb.size(); resp_buf_[0] = comb.size();
@@ -3632,7 +3631,6 @@ private:
} else if (msg_ == MSG_SELECT_SUM) { } else if (msg_ == MSG_SELECT_SUM) {
auto mode = read_u8(); auto mode = read_u8();
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto val = read_u32(); auto val = read_u32();
auto min = read_u8(); auto min = read_u8();
auto max = read_u8(); auto max = read_u8();
@@ -3761,6 +3759,7 @@ private:
options_.push_back(option); options_.push_back(option);
} }
to_play_ = player;
callback_ = [this, combs, must_select_size](int idx) { callback_ = [this, combs, must_select_size](int idx) {
const auto &comb = combs[idx]; const auto &comb = combs[idx];
resp_buf_[0] = must_select_size + comb.size(); resp_buf_[0] = must_select_size + comb.size();
@@ -3775,7 +3774,6 @@ private:
} else if (msg_ == MSG_SELECT_CHAIN) { } else if (msg_ == MSG_SELECT_CHAIN) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto size = read_u8(); auto size = read_u8();
auto spe_count = read_u8(); auto spe_count = read_u8();
bool forced = read_u8(); bool forced = read_u8();
@@ -3872,6 +3870,7 @@ private:
if (!forced) { if (!forced) {
options_.push_back("c"); options_.push_back("c");
} }
to_play_ = player;
callback_ = [this, forced](int idx) { callback_ = [this, forced](int idx) {
const auto &option = options_[idx]; const auto &option = options_[idx];
if ((option == "c") && (!forced)) { if ((option == "c") && (!forced)) {
@@ -3882,7 +3881,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_YESNO) { } else if (msg_ == MSG_SELECT_YESNO) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
if (verbose_) { if (verbose_) {
auto desc = read_u32(); auto desc = read_u32();
@@ -3907,6 +3905,7 @@ private:
dp_ += 4; dp_ += 4;
} }
options_ = {"y", "n"}; options_ = {"y", "n"};
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (idx == 0) { if (idx == 0) {
OCG_SetResponsei(pduel_, 1); OCG_SetResponsei(pduel_, 1);
@@ -3918,7 +3917,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_EFFECTYN) { } else if (msg_ == MSG_SELECT_EFFECTYN) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
std::string spec; std::string spec;
if (verbose_) { if (verbose_) {
@@ -3981,6 +3979,7 @@ private:
spec = ls_to_spec(loc, seq, pos, c != player); spec = ls_to_spec(loc, seq, pos, c != player);
} }
options_ = {"y " + spec, "n " + spec}; options_ = {"y " + spec, "n " + spec};
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (idx == 0) { if (idx == 0) {
OCG_SetResponsei(pduel_, 1); OCG_SetResponsei(pduel_, 1);
@@ -3992,7 +3991,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_OPTION) { } else if (msg_ == MSG_SELECT_OPTION) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto size = read_u8(); auto size = read_u8();
if (verbose_) { if (verbose_) {
auto pl = players_[player]; auto pl = players_[player];
@@ -4016,6 +4014,7 @@ private:
options_.push_back(std::to_string(i + 1)); options_.push_back(std::to_string(i + 1));
} }
} }
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
if (verbose_) { if (verbose_) {
players_[to_play_]->notify("You selected option " + options_[idx] + players_[to_play_]->notify("You selected option " + options_[idx] +
@@ -4029,7 +4028,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_IDLECMD) { } else if (msg_ == MSG_SELECT_IDLECMD) {
int32_t player = read_u8(); int32_t player = read_u8();
to_play_ = player;
auto summonable_ = read_cardlist_spec(); auto summonable_ = read_cardlist_spec();
auto spsummon_ = read_cardlist_spec(); auto spsummon_ = read_cardlist_spec();
auto repos_ = read_cardlist_spec(); auto repos_ = read_cardlist_spec();
@@ -4134,6 +4132,7 @@ private:
} }
} }
to_play_ = player;
callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset, callback_ = [this, spsummon_offset, repos_offset, mset_offset, set_offset,
activate_offset](int idx) { activate_offset](int idx) {
const auto &option = options_[idx]; const auto &option = options_[idx];
@@ -4169,7 +4168,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_PLACE) { } else if (msg_ == MSG_SELECT_PLACE) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto count = read_u8(); auto count = read_u8();
if (count == 0) { if (count == 0) {
count = 1; count = 1;
@@ -4189,6 +4187,7 @@ private:
" places for card, from " + specs_str + "."); " places for card, from " + specs_str + ".");
} }
} }
to_play_ = player;
callback_ = [this, player](int idx) { callback_ = [this, player](int idx) {
int y = player + 1; int y = player + 1;
std::string spec = options_[idx]; std::string spec = options_[idx];
@@ -4205,7 +4204,6 @@ private:
}; };
} else if (msg_ == MSG_SELECT_DISFIELD) { } else if (msg_ == MSG_SELECT_DISFIELD) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto count = read_u8(); auto count = read_u8();
if (count == 0) { if (count == 0) {
count = 1; count = 1;
@@ -4225,6 +4223,7 @@ private:
std::to_string(count) + " not implemented"); std::to_string(count) + " not implemented");
} }
} }
to_play_ = player;
callback_ = [this, player](int idx) { callback_ = [this, player](int idx) {
int y = player + 1; int y = player + 1;
std::string spec = options_[idx]; std::string spec = options_[idx];
@@ -4241,7 +4240,6 @@ private:
}; };
} else if (msg_ == MSG_ANNOUNCE_NUMBER) { } else if (msg_ == MSG_ANNOUNCE_NUMBER) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto count = read_u8(); auto count = read_u8();
std::vector<int> numbers; std::vector<int> numbers;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
@@ -4265,12 +4263,12 @@ private:
str += "]"; str += "]";
pl->notify(str); pl->notify(str);
} }
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
OCG_SetResponsei(pduel_, idx); OCG_SetResponsei(pduel_, idx);
}; };
} else if (msg_ == MSG_ANNOUNCE_ATTRIB) { } else if (msg_ == MSG_ANNOUNCE_ATTRIB) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto count = read_u8(); auto count = read_u8();
auto flag = read_u32(); auto flag = read_u32();
@@ -4310,6 +4308,7 @@ private:
options_.push_back(option); options_.push_back(option);
} }
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
const auto &option = options_[idx]; const auto &option = options_[idx];
uint32_t resp = 0; uint32_t resp = 0;
@@ -4323,7 +4322,6 @@ private:
} else if (msg_ == MSG_SELECT_POSITION) { } else if (msg_ == MSG_SELECT_POSITION) {
auto player = read_u8(); auto player = read_u8();
to_play_ = player;
auto code = read_u32(); auto code = read_u32();
auto valid_pos = read_u8(); auto valid_pos = read_u8();
@@ -4348,6 +4346,7 @@ private:
i++; i++;
} }
to_play_ = player;
callback_ = [this](int idx) { callback_ = [this](int idx) {
uint8_t pos = options_[idx][0] - '1'; uint8_t pos = options_[idx][0] - '1';
OCG_SetResponsei(pduel_, 1 << pos); OCG_SetResponsei(pduel_, 1 << pos);
...