Commit 80707a8c authored by Biluo Shen

add OSFP

parent 11261948
 *.pt
+*.ptj
 *.pkl
 # Xmake cache
......
@@ -3,7 +3,7 @@ import random
 import time
 from collections import deque
 from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional
 import ygoenv
@@ -52,10 +52,8 @@ class Args:
 """the embedding file for card embeddings"""
 max_options: int = 24
 """the maximum number of options"""
-n_history_actions: int = 16
+n_history_actions: int = 32
 """the number of history actions to use"""
-play_mode: str = "bot"
-"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
 num_layers: int = 2
 """the number of layers for the agent"""
@@ -74,9 +72,9 @@ class Args:
 """the number of steps to run in each environment per policy rollout"""
 anneal_lr: bool = True
 """Toggle learning rate annealing for policy and value networks"""
-gamma: float = 0.997
+gamma: float = 1.0
 """the discount factor gamma"""
-gae_lambda: float = 0.95
+gae_lambda: float = 0.98
 """the lambda for the general advantage estimation"""
 minibatch_size: int = 256
@@ -85,7 +83,7 @@ class Args:
 """the K epochs to update the policy"""
 norm_adv: bool = True
 """Toggles advantages normalization"""
-clip_coef: float = 0.1
+clip_coef: float = 0.2
 """the surrogate clipping coefficient"""
 clip_vloss: bool = True
 """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
@@ -93,17 +91,13 @@ class Args:
 """coefficient of the entropy"""
 vf_coef: float = 0.5
 """coefficient of the value function"""
-max_grad_norm: float = 0.5
+max_grad_norm: float = 1.0
 """the maximum norm for the gradient clipping"""
-target_kl: Optional[float] = None
-"""the target KL divergence threshold"""
-learn_opponent: bool = True
+learn_opponent: bool = False
 """if toggled, the samples from the opponent will be used to train the agent"""
 collect_length: Optional[int] = None
 """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-"""the backend for distributed training"""
 compile: Optional[str] = None
 """Compile mode of torch.compile, None for no compilation"""
 torch_threads: Optional[int] = None
@@ -125,7 +119,7 @@ class Args:
 """the probability of logging"""
 eval_episodes: int = 128
 """the number of episodes to evaluate the model"""
-eval_interval: int = 10
+eval_interval: int = 50
 """the number of iterations to evaluate the model"""
 # to be filled in runtime
@@ -143,6 +137,23 @@ class Args:
 """the number of processes (computed in runtime)"""
+def make_env(args, num_envs, num_threads, mode='self'):
+envs = ygoenv.make(
+task_id=args.env_id,
+env_type="gymnasium",
+num_envs=num_envs,
+num_threads=num_threads,
+seed=args.seed,
+deck1=args.deck1,
+deck2=args.deck2,
+max_options=args.max_options,
+n_history_actions=args.n_history_actions,
+play_mode='self',
+)
+envs.num_envs = num_envs
+envs = RecordEpisodeStatistics(envs)
+return envs
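For readability, here is the `make_env` helper from the hunk above, consolidated with indentation restored. It relies on the script's existing `ygoenv` import and `RecordEpisodeStatistics` wrapper. One detail is an assumption rather than a verbatim copy: the committed hunk hard-codes `play_mode='self'` and never uses the `mode` parameter, while the call sites later in this diff pass `mode='bot'` for the evaluation envs, so this sketch forwards `mode` to `play_mode` on the assumption that this is the intent.

def make_env(args, num_envs, num_threads, mode='self'):
    # Build a vectorized YGO env; `mode` selects the opponent behaviour
    # ('self' for self-play training, 'bot' for greedy rule-based evaluation).
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=num_threads,
        seed=args.seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode=mode,  # assumption: forward `mode`; the hunk above hard-codes 'self'
    )
    envs.num_envs = num_envs
    envs = RecordEpisodeStatistics(envs)
    return envs

# Call pattern used later in this diff:
#   envs = make_env(args, args.local_num_envs, local_env_threads)
#   eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')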
 def main():
 rank = int(os.environ.get("RANK", 0))
 local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -169,7 +180,7 @@ def main():
 torch.set_float32_matmul_precision('high')
 if args.world_size > 1:
-torchrun_setup(args.backend, local_rank)
+torchrun_setup('nccl', local_rank)
 timestamp = int(time.time())
 run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
@@ -204,43 +215,17 @@ def main():
 args.deck2 = args.deck2 or deck
 # env setup
-envs = ygoenv.make(
-task_id=args.env_id,
-env_type="gymnasium",
-num_envs=args.local_num_envs,
-num_threads=local_env_threads,
-seed=args.seed,
-deck1=args.deck1,
-deck2=args.deck2,
-max_options=args.max_options,
-n_history_actions=args.n_history_actions,
-play_mode='self',
-)
-envs.num_envs = args.local_num_envs
-obs_space = envs.observation_space
-action_shape = envs.action_space.shape
+envs = make_env(args, args.local_num_envs, local_env_threads)
+obs_space = envs.env.observation_space
+action_shape = envs.env.action_space.shape
 if local_rank == 0:
 fprint(f"obs_space={obs_space}, action_shape={action_shape}")
 envs_per_thread = args.local_num_envs // local_env_threads
 local_eval_episodes = args.eval_episodes // args.world_size
 local_eval_num_envs = local_eval_episodes
-eval_envs = ygoenv.make(
-task_id=args.env_id,
-env_type="gymnasium",
-num_envs=local_eval_num_envs,
-num_threads=max(1, local_eval_num_envs // envs_per_thread),
-seed=args.seed,
-deck1=args.deck1,
-deck2=args.deck2,
-max_options=args.max_options,
-n_history_actions=args.n_history_actions,
-play_mode=args.play_mode,
-)
-eval_envs.num_envs = local_eval_num_envs
-envs = RecordEpisodeStatistics(envs)
-eval_envs = RecordEpisodeStatistics(eval_envs)
+local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
+eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
 if args.embedding_file:
 embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -312,10 +297,10 @@ def main():
 next_value1 = next_value2 = 0
 step = 0
-for iteration in range(1, args.num_iterations + 1):
+for iteration in range(args.num_iterations):
 # Annealing the rate if instructed to do so.
 if args.anneal_lr:
-frac = 1.0 - (iteration - 1.0) / args.num_iterations
+frac = 1.0 - iteration / args.num_iterations
 lrnow = frac * args.learning_rate
 optimizer.param_groups[0]["lr"] = lrnow
@@ -372,7 +357,7 @@ def main():
 if random.random() < args.log_p:
 n = 100
-if random.random() < 10/n or iteration <= 2:
+if random.random() < 10/n or iteration <= 1:
 writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
 writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
 fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
@@ -394,7 +379,7 @@ def main():
 nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
 nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
-if step > 0 and iteration != 1:
+if step > 0 and iteration != 0:
 # recalculate the values for the first few steps
 v_steps = args.local_minibatch_size * 4 // args.local_num_envs
 for v_start in range(0, step, v_steps):
@@ -421,8 +406,11 @@ def main():
 b_logprobs = logprobs[:args.num_steps].reshape(-1)
 b_advantages = advantages[:args.num_steps].reshape(-1)
 b_values = values[:args.num_steps].reshape(-1)
-b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
 b_returns = b_advantages + b_values
+if args.learn_opponent:
+b_learns = torch.ones_like(b_values, dtype=torch.bool)
+else:
+b_learns = learns[:args.num_steps].reshape(-1)
 # Optimizing the policy and value network
 b_inds = np.arange(args.local_batch_size)
@@ -444,9 +432,6 @@ def main():
 scaler.update()
 clipfracs.append(clipfrac.item())
-if args.target_kl is not None and approx_kl > args.target_kl:
-break
 if step > 0:
 # TODO: use cyclic buffer to avoid copying
 for v in obs.values():
@@ -463,7 +448,6 @@ def main():
 var_y = np.var(y_true)
 explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
-# TRY NOT TO MODIFY: record rewards for plotting purposes
 if rank == 0:
 if iteration % args.save_interval == 0:
 torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
@@ -490,7 +474,7 @@ def main():
 if rank == 0:
 writer.add_scalar("charts/SPS", SPS, global_step)
-if iteration % args.eval_interval == 0:
+if args.eval_interval and iteration % args.eval_interval == 0:
 # Eval with rule-based policy
 _start = time.time()
 eval_return = evaluate(
......
 import os
 import random
 import time
-from collections import deque
 from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional
 import ygoenv
@@ -52,10 +51,8 @@ class Args:
 """the embedding file for card embeddings"""
 max_options: int = 24
 """the maximum number of options"""
-n_history_actions: int = 16
+n_history_actions: int = 32
 """the number of history actions to use"""
-play_mode: str = "bot"
-"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
 num_layers: int = 2
 """the number of layers for the agent"""
@@ -74,15 +71,21 @@ class Args:
 """the number of steps to run in each environment per policy rollout"""
 anneal_lr: bool = True
 """Toggle learning rate annealing for policy and value networks"""
-gamma: float = 0.997
+gamma: float = 1.0
 """the discount factor gamma"""
-gae_lambda: float = 0.95
+gae_lambda: float = 0.98
 """the lambda for the general advantage estimation"""
 update_win_rate: float = 0.55
 """the required win rate to update the agent"""
-update_return: float = 0.1
-"""the required return to update the agent"""
+self_play_prob: float = 0.6
+"""the probability of self play"""
+max_lp: int = 6
+"""the maximum number of LP to add model to the pool"""
+iter_per_lp: int = 1000
+"""the number of iterations per learning phase"""
+target_sample_iter: int = 10
+"""the number of iterations to sample the target model"""
 minibatch_size: int = 256
 """the mini-batch size"""
@@ -90,7 +93,7 @@ class Args:
 """the K epochs to update the policy"""
 norm_adv: bool = True
 """Toggles advantages normalization"""
-clip_coef: float = 0.1
+clip_coef: float = 0.2
 """the surrogate clipping coefficient"""
 clip_vloss: bool = True
 """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
@@ -98,17 +101,13 @@ class Args:
 """coefficient of the entropy"""
 vf_coef: float = 0.5
 """coefficient of the value function"""
-max_grad_norm: float = 0.5
+max_grad_norm: float = 1.0
 """the maximum norm for the gradient clipping"""
-target_kl: Optional[float] = None
-"""the target KL divergence threshold"""
-learn_opponent: bool = True
+learn_opponent: bool = False
 """if toggled, the samples from the opponent will be used to train the agent"""
 collect_length: Optional[int] = None
 """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-"""the backend for distributed training"""
 compile: Optional[str] = None
 """Compile mode of torch.compile, None for no compilation"""
 torch_threads: Optional[int] = None
@@ -130,7 +129,7 @@ class Args:
 """the probability of logging"""
 eval_episodes: int = 128
 """the number of episodes to evaluate the model"""
-eval_interval: int = 10
+eval_interval: int = 50
 """the number of iterations to evaluate the model"""
 # to be filled in runtime
@@ -148,6 +147,27 @@ class Args:
 """the number of processes (computed in runtime)"""
+def make_env(args, num_envs, num_threads, mode='self'):
+envs = ygoenv.make(
+task_id=args.env_id,
+env_type="gymnasium",
+num_envs=num_envs,
+num_threads=num_threads,
+seed=args.seed,
+deck1=args.deck1,
+deck2=args.deck2,
+max_options=args.max_options,
+n_history_actions=args.n_history_actions,
+play_mode='self',
+)
+envs.num_envs = num_envs
+envs = RecordEpisodeStatistics(envs)
+return envs
+def update_running_mean(mean, value, count):
+return mean + (value - mean) / count
 def main():
 rank = int(os.environ.get("RANK", 0))
 local_rank = int(os.environ.get("LOCAL_RANK", 0))
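The `update_running_mean` helper added above is the standard incremental mean update, mean_new = mean_old + (value - mean_old) / count. The call sites later in this diff increment `n_episodes[t_idx]` before calling it, so `count` is already the new sample count. A minimal sketch with an illustrative numeric check (the helper body is taken from the hunk above; the check values are made up for illustration):

def update_running_mean(mean, value, count):
    # After `count` samples, the mean becomes mean + (value - mean) / count.
    return mean + (value - mean) / count

# Illustrative check: the running mean of 1.0, 0.0, 1.0 is 2/3.
m = 0.0
for n, x in enumerate([1.0, 0.0, 1.0], start=1):
    m = update_running_mean(m, x, n)
assert abs(m - 2 / 3) < 1e-9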
@@ -174,7 +194,19 @@ def main():
 torch.set_float32_matmul_precision('high')
 if args.world_size > 1:
-torchrun_setup(args.backend, local_rank)
+torchrun_setup('nccl', local_rank)
+def sync_var(var, dtype=torch.float32, reduce='first'):
+ts = torch.tensor(var, dtype=dtype, device=device)
+if reduce == 'mean':
+if args.world_size > 1:
+dist.all_reduce(ts, op=dist.ReduceOp.AVG)
+else:
+if rank != 0:
+ts = torch.zeros_like(ts)
+if args.world_size > 1:
+dist.all_reduce(ts, op=dist.ReduceOp.SUM)
+return ts.cpu().numpy()
 timestamp = int(time.time())
 run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
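The `sync_var` closure added above keeps values consistent across ranks: with `reduce='mean'` it averages with an AVG all-reduce, and with the default `reduce='first'` it zeroes the tensor on every rank except 0 and then SUM-all-reduces, which amounts to broadcasting rank 0's value. A consolidated sketch with indentation restored, assuming `args`, `device`, `rank`, and `torch.distributed` (as `dist`) are in scope as elsewhere in the script:

def sync_var(var, dtype=torch.float32, reduce='first'):
    # Build a tensor on this rank's device from the Python value or list.
    ts = torch.tensor(var, dtype=dtype, device=device)
    if reduce == 'mean':
        # Average the value across all ranks.
        if args.world_size > 1:
            dist.all_reduce(ts, op=dist.ReduceOp.AVG)
    else:
        # 'first': zero out every rank except 0, then SUM, i.e. broadcast rank 0's value.
        if rank != 0:
            ts = torch.zeros_like(ts)
        if args.world_size > 1:
            dist.all_reduce(ts, op=dist.ReduceOp.SUM)
    return ts.cpu().numpy()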
@@ -209,43 +241,17 @@ def main():
 args.deck2 = args.deck2 or deck
 # env setup
-envs = ygoenv.make(
-task_id=args.env_id,
-env_type="gymnasium",
-num_envs=args.local_num_envs,
-num_threads=local_env_threads,
-seed=args.seed,
-deck1=args.deck1,
-deck2=args.deck2,
-max_options=args.max_options,
-n_history_actions=args.n_history_actions,
-play_mode='self',
-)
-envs.num_envs = args.local_num_envs
-obs_space = envs.observation_space
-action_shape = envs.action_space.shape
+envs = make_env(args, args.local_num_envs, local_env_threads)
+obs_space = envs.env.observation_space
+action_shape = envs.env.action_space.shape
 if local_rank == 0:
 fprint(f"obs_space={obs_space}, action_shape={action_shape}")
 envs_per_thread = args.local_num_envs // local_env_threads
 local_eval_episodes = args.eval_episodes // args.world_size
 local_eval_num_envs = local_eval_episodes
-eval_envs = ygoenv.make(
-task_id=args.env_id,
-env_type="gymnasium",
-num_envs=local_eval_num_envs,
-num_threads=max(1, local_eval_num_envs // envs_per_thread),
-seed=args.seed,
-deck1=args.deck1,
-deck2=args.deck2,
-max_options=args.max_options,
-n_history_actions=args.n_history_actions,
-play_mode=args.play_mode,
-)
-eval_envs.num_envs = local_eval_num_envs
-envs = RecordEpisodeStatistics(envs)
-eval_envs = RecordEpisodeStatistics(eval_envs)
+local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
+eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
 if args.embedding_file:
 embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -280,6 +286,8 @@ def main():
 logits, value, valid = agent(next_obs)
 return logits, value
+history = []
 from ygoai.rl.ppo import train_step
 if args.compile:
 # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
@@ -287,13 +295,23 @@ def main():
 example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
 with torch.no_grad():
 traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
+if args.checkpoint:
+with torch.no_grad():
 traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
 traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
+history.append(traced_model_t)
 train_step = torch.compile(train_step, mode=args.compile)
-else:
-traced_model = agent
-traced_model_t = agent_t
+def sample_target(history):
+ts = []
+for i in range(args.target_sample_iter):
+if len(history) == 0 or random.random() < args.self_play_prob:
+ts.append(-1)
+else:
+ts.append(random.randint(0, len(history) - 1))
+ts.sort(reverse=True)
+return sync_var(ts, dtype=torch.int64).tolist()
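`sample_target` draws the opponent schedule for the next `target_sample_iter` iterations: each slot is self-play (encoded as -1) with probability `self_play_prob`, otherwise a uniformly random snapshot index from the `history` pool, and the schedule is then synchronized through `sync_var` so every rank consumes the same sequence. A consolidated sketch with indentation restored; it assumes `args`, `random`, and `sync_var` from the surrounding script:

def sample_target(history):
    # Opponent schedule for the next target_sample_iter iterations:
    # -1 encodes self-play, otherwise an index into the snapshot pool `history`.
    ts = []
    for i in range(args.target_sample_iter):
        if len(history) == 0 or random.random() < args.self_play_prob:
            ts.append(-1)
        else:
            ts.append(random.randint(0, len(history) - 1))
    # Descending sort so that ts.pop() in the training loop yields the
    # self-play slots first, then pool indices in ascending order.
    ts.sort(reverse=True)
    # Broadcast rank 0's schedule so all ranks face the same targets.
    return sync_var(ts, dtype=torch.int64).tolist()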
 # ALGO Logic: Storage setup
 obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
@@ -303,9 +321,9 @@ def main():
 dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
 values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
 learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
-avg_ep_returns = deque(maxlen=1000)
-avg_win_rates = deque(maxlen=1000)
-version = 0
+avg_ep_returns = [0]
+avg_win_rates = [0]
+n_episodes = [0]
 # TRY NOT TO MODIFY: start the game
 global_step = 0
@@ -324,14 +342,23 @@ def main():
 ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
 next_value1 = next_value2 = 0
 step = 0
+lp_count = 0
+ts = sample_target(history)
-for iteration in range(1, args.num_iterations + 1):
+for iteration in range(args.num_iterations):
 # Annealing the rate if instructed to do so.
 if args.anneal_lr:
-frac = 1.0 - (iteration - 1.0) / args.num_iterations
+frac = 1.0 - (iteration % args.iter_per_lp) / args.iter_per_lp
 lrnow = frac * args.learning_rate
 optimizer.param_groups[0]["lr"] = lrnow
+if len(ts) == 0:
+ts = sample_target(history)
+t_idx = ts.pop()
+selfplay = t_idx == -1
+if not selfplay:
+traced_model_t = history[t_idx]
 model_time = 0
 env_time = 0
 collect_start = time.time()
@@ -346,6 +373,7 @@ def main():
 _start = time.time()
 logits, value = predict_step(traced_model, next_obs)
+if not selfplay:
 logits_t, value_t = predict_step(traced_model_t, next_obs)
 logits = torch.where(learn[:, None], logits, logits_t)
 value = torch.where(learn[:, None], value, value_t)
@@ -374,21 +402,20 @@ def main():
 next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
 step += 1
-if not writer:
-continue
 for idx, d in enumerate(next_done_):
 if d:
 pl = 1 if to_play[idx] == ai_player1_[idx] else -1
 episode_length = info['l'][idx]
 episode_reward = info['r'][idx] * pl
 win = 1 if episode_reward > 0 else 0
-avg_ep_returns.append(episode_reward)
-avg_win_rates.append(win)
+if len(history) == 0 or not selfplay:
+n_episodes[t_idx] += 1
+avg_ep_returns[t_idx] = update_running_mean(avg_ep_returns[t_idx], episode_reward, n_episodes[t_idx])
+avg_win_rates[t_idx] = update_running_mean(avg_win_rates[t_idx], win, n_episodes[t_idx])
-if random.random() < args.log_p:
+if writer and random.random() < args.log_p:
 n = 100
-if random.random() < 10/n or iteration <= 2:
+if random.random() < 10/n or iteration <= 1:
 writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
 writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
 fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
@@ -407,12 +434,13 @@ def main():
 # bootstrap value if not done
 with torch.no_grad():
 value = predict_step(traced_model, next_obs)[1].reshape(-1)
+if not selfplay:
 value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
 value = torch.where(next_to_play == ai_player1, value, value_t)
 nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
 nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
-if step > 0 and iteration != 1:
+if step > 0 and iteration != 0:
 # recalculate the values for the first few steps
 v_steps = args.local_minibatch_size * 4 // args.local_num_envs
 for v_start in range(0, step, v_steps):
@@ -439,8 +467,11 @@ def main():
 b_logprobs = logprobs[:args.num_steps].reshape(-1)
 b_advantages = advantages[:args.num_steps].reshape(-1)
 b_values = values[:args.num_steps].reshape(-1)
-b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
 b_returns = b_advantages + b_values
+if args.learn_opponent or selfplay:
+b_learns = torch.ones_like(b_values, dtype=torch.bool)
+else:
+b_learns = learns[:args.num_steps].reshape(-1)
 # Optimizing the policy and value network
 b_inds = np.arange(args.local_batch_size)
@@ -462,9 +493,6 @@ def main():
 scaler.update()
 clipfracs.append(clipfrac.item())
-if args.target_kl is not None and approx_kl > args.target_kl:
-break
 if step > 0:
 # TODO: use cyclic buffer to avoid copying
 for v in obs.values():
@@ -481,7 +509,6 @@ def main():
 var_y = np.var(y_true)
 explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
-# TRY NOT TO MODIFY: record rewards for plotting purposes
 if rank == 0:
 if iteration % args.save_interval == 0:
 torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
@@ -508,27 +535,29 @@ def main():
 if rank == 0:
 writer.add_scalar("charts/SPS", SPS, global_step)
-if rank == 0:
-should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
-should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
-else:
-should_update = torch.zeros((), dtype=torch.int64, device=device)
-if args.world_size > 1:
-dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
-should_update = should_update.item() > 0
-if should_update:
+if (iteration + 1) % args.iter_per_lp == 0:
+lp_count += 1
+win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
+if np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
 agent_t.load_state_dict(agent.state_dict())
 with torch.no_grad():
 traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
 traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
+history.append(traced_model_t)
-version += 1
+lp_count = 0
 if rank == 0:
+version = len(history)
 torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
-print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
-avg_win_rates.clear()
-avg_ep_returns.clear()
+fprint(f"model v{version} added to the pool, win_rates={win_rates}")
+else:
+if rank == 0:
+fprint(f"win_rates={win_rates}, not updating the pool")
+avg_ep_returns = [0] * len(history)
+avg_win_rates = [0] * len(history)
+n_episodes = [0] * len(history)
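Read as a whole, the hunk above replaces the old single-threshold update with a learning-phase gate. A sketch of that gate, with nesting inferred from the surrounding code rather than copied verbatim, and relying on names from the script (`sync_var`, `agent`, `agent_t`, `example_obs`, `ckpt_dir`, `fprint`): every `iter_per_lp` iterations the per-opponent win rates are averaged across ranks; if the agent beats every opponent in the pool above `update_win_rate`, or `max_lp` phases have elapsed, the current weights are frozen, traced for inference, and appended to `history`, after which the per-opponent statistics are reset to the new pool size.

# Sketch of the learning-phase gate above; indentation and nesting are inferred.
if (iteration + 1) % args.iter_per_lp == 0:
    lp_count += 1
    # Average per-opponent win rates over all ranks.
    win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
    if np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
        # Freeze the current policy and add it to the opponent pool.
        agent_t.load_state_dict(agent.state_dict())
        with torch.no_grad():
            traced_model_t = torch.jit.trace(agent_t, (example_obs,),
                                             check_tolerance=False, check_trace=False)
        traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
        history.append(traced_model_t)
        lp_count = 0
        if rank == 0:
            version = len(history)
            torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
            fprint(f"model v{version} added to the pool, win_rates={win_rates}")
    else:
        if rank == 0:
            fprint(f"win_rates={win_rates}, not updating the pool")
    # Reset per-opponent statistics to match the (possibly grown) pool.
    avg_ep_returns = [0] * len(history)
    avg_win_rates = [0] * len(history)
    n_episodes = [0] * len(history)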
+if args.eval_interval and iteration % args.eval_interval == 0:
+# Eval with rule-based policy
 _start = time.time()
 eval_return = evaluate(
 eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
......