Commit 80707a8c authored by Biluo Shen

add OSFP

parent 11261948
*.pt
*.ptj
*.pkl
# Xmake cache
......
......@@ -3,7 +3,7 @@ import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
from typing import Optional
import ygoenv
......@@ -52,10 +52,8 @@ class Args:
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -74,9 +72,9 @@ class Args:
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
gae_lambda: float = 0.98
"""the lambda for the general advantage estimation"""
minibatch_size: int = 256
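The gamma and gae_lambda changes above (discounting dropped to 1.0, lambda raised to 0.98) enter the generalized advantage estimation. Below is a minimal single-agent sketch of that recursion for reference only; tensor names are illustrative and the two-player bookkeeping in this file (nextvalues1/nextvalues2 further down) builds on the same idea:

import torch

def gae_sketch(rewards, values, dones, next_value, next_done, gamma=1.0, gae_lambda=0.98):
    # rewards, values, dones: (num_steps, num_envs); next_value, next_done: (num_envs,)
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_value)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal, nextvalues = 1.0 - next_done.float(), next_value
        else:
            nextnonterminal, nextvalues = 1.0 - dones[t + 1].float(), values[t + 1]
        # one-step TD error, exponentially averaged with weight gamma * gae_lambda
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    return advantages  # returns = advantages + values, as in the batching code below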
......@@ -85,7 +83,7 @@ class Args:
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
......@@ -93,17 +91,13 @@ class Args:
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
......@@ -125,7 +119,7 @@ class Args:
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
......@@ -143,6 +137,23 @@ class Args:
"""the number of processes (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
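Note that make_env accepts a mode argument but still passes the literal play_mode='self' to ygoenv.make, so the later make_env(..., mode='bot') call for the evaluation envs does not actually change the play mode. A minimal sketch of the presumably intended wiring (illustrative, not the committed code):

def make_env_sketch(args, num_envs, num_threads, mode='self'):
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=num_threads,
        seed=args.seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode=mode,  # forward the argument instead of hardcoding 'self'
    )
    envs.num_envs = num_envs
    return RecordEpisodeStatistics(envs)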
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
......@@ -169,7 +180,7 @@ def main():
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
torchrun_setup('nccl', local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
......@@ -204,43 +215,17 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
envs = make_env(args, args.local_num_envs, local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
......@@ -312,10 +297,10 @@ def main():
next_value1 = next_value2 = 0
step = 0
for iteration in range(1, args.num_iterations + 1):
for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
frac = 1.0 - iteration / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
......@@ -372,7 +357,7 @@ def main():
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
......@@ -394,7 +379,7 @@ def main():
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 1:
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
......@@ -421,8 +406,11 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.learn_opponent:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
b_learns = learns[:args.num_steps].reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
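With learn_opponent now defaulting to False, b_learns marks which transitions were produced by the learning agent rather than the frozen opponent. How the mask enters the loss is inside ygoai.rl.ppo.train_step and not visible here; a plausible masked-mean form, stated only as an assumption, would be:

# assumed masking scheme, not taken from ygoai.rl.ppo
#   pg_loss = (per_sample_pg_loss * b_learns).sum() / b_learns.sum().clamp(min=1)
#   v_loss  = (per_sample_v_loss  * b_learns).sum() / b_learns.sum().clamp(min=1)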
......@@ -444,9 +432,6 @@ def main():
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
......@@ -463,7 +448,6 @@ def main():
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
......@@ -490,7 +474,7 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
......
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
from typing import Optional
import ygoenv
......@@ -52,10 +51,8 @@ class Args:
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -74,15 +71,21 @@ class Args:
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
gae_lambda: float = 0.98
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
self_play_prob: float = 0.6
"""the probability of self play"""
max_lp: int = 6
"""the maximum number of LP to add model to the pool"""
iter_per_lp: int = 1000
"""the number of iterations per learning phase"""
target_sample_iter: int = 10
"""the number of iterations to sample the target model"""
minibatch_size: int = 256
"""the mini-batch size"""
......@@ -90,7 +93,7 @@ class Args:
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
......@@ -98,17 +101,13 @@ class Args:
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = True
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
......@@ -130,7 +129,7 @@ class Args:
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
......@@ -148,6 +147,27 @@ class Args:
"""the number of processes (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
def update_running_mean(mean, value, count):
return mean + (value - mean) / count
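update_running_mean is the standard incremental mean, m_k = m_{k-1} + (x_k - m_{k-1}) / k, which lets the per-opponent win rates be tracked without storing episode histories. A quick check:

mean, count = 0.0, 0
for x in [1.0, 0.0, 1.0, 1.0]:
    count += 1                                   # the caller increments n_episodes first
    mean = update_running_mean(mean, x, count)
print(mean)  # 0.75 == sum([1, 0, 1, 1]) / 4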
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
......@@ -174,7 +194,19 @@ def main():
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
torchrun_setup('nccl', local_rank)
def sync_var(var, dtype=torch.float32, reduce='first'):
ts = torch.tensor(var, dtype=dtype, device=device)
if reduce == 'mean':
if args.world_size > 1:
dist.all_reduce(ts, op=dist.ReduceOp.AVG)
else:
if rank != 0:
ts = torch.zeros_like(ts)
if args.world_size > 1:
dist.all_reduce(ts, op=dist.ReduceOp.SUM)
return ts.cpu().numpy()
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
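sync_var above supports two reductions: reduce='mean' averages a value across ranks (used later for the per-opponent win rates), while the default reduce='first' broadcasts rank 0's value by zeroing the tensor on every other rank before an all-reduce SUM (used to keep the sampled opponent schedule identical on all workers). A worked example with world_size=2, values illustrative:

# reduce='first':  rank 0 holds ts = [3, 7]; rank 1 zeroes its copy to [0, 0]
#                  all_reduce(SUM) -> both ranks end with [3, 7], i.e. rank 0's draw
# reduce='mean':   rank 0 holds 0.60, rank 1 holds 0.50
#                  all_reduce(AVG) -> both ranks end with 0.55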
......@@ -209,43 +241,17 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
envs = make_env(args, args.local_num_envs, local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
......@@ -280,6 +286,8 @@ def main():
logits, value, valid = agent(next_obs)
return logits, value
history = []
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
......@@ -287,13 +295,23 @@ def main():
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
if args.checkpoint:
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
history.append(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
def sample_target(history):
ts = []
for i in range(args.target_sample_iter):
if len(history) == 0 or random.random() < args.self_play_prob:
ts.append(-1)
else:
ts.append(random.randint(0, len(history) - 1))
ts.sort(reverse=True)
return sync_var(ts, dtype=torch.int64).tolist()
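sample_target draws the opponent schedule for the next target_sample_iter iterations: -1 means self-play (probability self_play_prob), anything else is a uniform index into the snapshot pool. Sorting in reverse order means ts.pop() consumes the self-play iterations first and then the pool opponents in ascending index, and sync_var makes every rank use rank 0's draw. An illustrative draw with a pool of 3 snapshots:

# ts after sorting (before sync): [2, 1, 0, 0, -1, -1, -1, -1, -1, -1]
# ts.pop() order per iteration:   -1, -1, -1, -1, -1, -1, 0, 0, 1, 2
#   i.e. six self-play iterations, then snapshots 0, 0, 1, 2 as frozen opponents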
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
......@@ -303,9 +321,9 @@ def main():
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
avg_ep_returns = [0]
avg_win_rates = [0]
n_episodes = [0]
# TRY NOT TO MODIFY: start the game
global_step = 0
......@@ -324,14 +342,23 @@ def main():
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0
lp_count = 0
ts = sample_target(history)
for iteration in range(1, args.num_iterations + 1):
for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
frac = 1.0 - (iteration % args.iter_per_lp) / args.iter_per_lp
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
if len(ts) == 0:
ts = sample_target(history)
t_idx = ts.pop()
selfplay = t_idx == -1
if not selfplay:
traced_model_t = history[t_idx]
model_time = 0
env_time = 0
collect_start = time.time()
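Two consequences of the loop changes above: the learning-rate schedule now restarts at the top of every learning phase instead of decaying once over the whole run, and each iteration either self-plays (t_idx == -1) or pins one frozen snapshot from history as the opponent for the whole rollout. A small numeric check of the restart (the base learning_rate value here is only illustrative):

learning_rate, iter_per_lp = 2.5e-4, 1000
for iteration in (0, 500, 999, 1000):
    frac = 1.0 - (iteration % iter_per_lp) / iter_per_lp
    print(iteration, frac * learning_rate)
# 0    -> 2.5e-4   (full LR at the start of a phase)
# 500  -> 1.25e-4
# 999  -> 2.5e-7   (annealed to near zero at the end of the phase)
# 1000 -> 2.5e-4   (schedule resets with the next phase)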
......@@ -346,9 +373,10 @@ def main():
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
if not selfplay:
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
......@@ -374,21 +402,20 @@ def main():
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if len(history) == 0 or not selfplay:
n_episodes[t_idx] += 1
avg_ep_returns[t_idx] = update_running_mean(avg_ep_returns[t_idx], episode_reward, n_episodes[t_idx])
avg_win_rates[t_idx] = update_running_mean(avg_win_rates[t_idx], win, n_episodes[t_idx])
if random.random() < args.log_p:
if writer and random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
......@@ -407,12 +434,13 @@ def main():
# bootstrap value if not done
with torch.no_grad():
value = predict_step(traced_model, next_obs)[1].reshape(-1)
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
if not selfplay:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 1:
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
......@@ -439,8 +467,11 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.learn_opponent or selfplay:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
b_learns = learns[:args.num_steps].reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
......@@ -462,9 +493,6 @@ def main():
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
......@@ -481,7 +509,6 @@ def main():
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
......@@ -508,27 +535,29 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
if (iteration + 1) % args.iter_per_lp == 0:
lp_count += 1
win_rates = sync_var(avg_win_rates, dtype=torch.float32, reduce='mean')
if np.all(win_rates > args.update_win_rate) or lp_count >= args.max_lp:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
history.append(traced_model_t)
lp_count = 0
if rank == 0:
version = len(history)
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
fprint(f"model v{version} added to the pool, win_rates={win_rates}")
else:
if rank == 0:
fprint(f"win_rates={win_rates}, not updating the pool")
avg_ep_returns = [0] * len(history)
avg_win_rates = [0] * len(history)
n_episodes = [0] * len(history)
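The block above is the core of the pool update: at the end of each learning phase the per-opponent win rates are averaged across ranks, and the current agent is frozen into history either when its win rate exceeds the required threshold against every opponent in the pool or when max_lp phases have elapsed without reaching it; the running statistics are then re-initialized with one slot per pool entry. A worked example of the criterion with the defaults:

# update_win_rate = 0.55, max_lp = 6, pool of 3 snapshots
# win_rates = [0.62, 0.58, 0.71] -> all > 0.55     -> snapshot now
# win_rates = [0.62, 0.48, 0.71] -> not all > 0.55 -> keep training, unless this
#     is already the 6th phase since the last snapshot (lp_count >= max_lp),
#     in which case the snapshot is forced anyway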
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
......