Commit 80707a8c authored by Biluo Shen

add OSFP

parent 11261948
*.pt
*.ptj
*.pkl
# Xmake cache
......
@@ -3,7 +3,7 @@ import random
import time
from collections import deque
from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional
import ygoenv
@@ -52,10 +52,8 @@ class Args:
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
-n_history_actions: int = 16
+n_history_actions: int = 32
"""the number of history actions to use"""
-play_mode: str = "bot"
-"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
@@ -74,9 +72,9 @@ class Args:
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
-gamma: float = 0.997
+gamma: float = 1.0
"""the discount factor gamma"""
-gae_lambda: float = 0.95
+gae_lambda: float = 0.98
"""the lambda for the general advantage estimation"""
minibatch_size: int = 256
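Setting gamma=1.0 drops discounting entirely (reasonable for short games decided by a terminal win/loss reward), while gae_lambda=0.98 smooths advantages over nearly full returns. For reference, a minimal textbook GAE sketch in terms of the new hyperparameters (an illustration only; this script's own two-player advantage computation is collapsed in this diff):

import torch

def gae_advantages(rewards, values, next_values, nonterminals, gamma=1.0, gae_lambda=0.98):
    # textbook single-agent GAE; all inputs are tensors of shape [num_steps, num_envs]
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * next_values[t] * nonterminals[t] - values[t]
        lastgaelam = delta + gamma * gae_lambda * nonterminals[t] * lastgaelam
        advantages[t] = lastgaelam
    return advantages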
@@ -85,7 +83,7 @@ class Args:
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
-clip_coef: float = 0.1
+clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
@@ -93,17 +91,13 @@ class Args:
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
-max_grad_norm: float = 0.5
+max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
-learn_opponent: bool = True
+learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
@@ -125,7 +119,7 @@ class Args:
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
-eval_interval: int = 10
+eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
@@ -143,6 +137,23 @@ class Args:
"""the number of processes (computed in runtime)"""
+def make_env(args, num_envs, num_threads, mode='self'):
+    envs = ygoenv.make(
+        task_id=args.env_id,
+        env_type="gymnasium",
+        num_envs=num_envs,
+        num_threads=num_threads,
+        seed=args.seed,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        max_options=args.max_options,
+        n_history_actions=args.n_history_actions,
+        play_mode=mode,
+    )
+    envs.num_envs = num_envs
+    envs = RecordEpisodeStatistics(envs)
+    return envs
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -169,7 +180,7 @@ def main():
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
-torchrun_setup(args.backend, local_rank)
+torchrun_setup('nccl', local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
@@ -204,43 +215,17 @@ def main():
args.deck2 = args.deck2 or deck
# env setup
-envs = ygoenv.make(
-    task_id=args.env_id,
-    env_type="gymnasium",
-    num_envs=args.local_num_envs,
-    num_threads=local_env_threads,
-    seed=args.seed,
-    deck1=args.deck1,
-    deck2=args.deck2,
-    max_options=args.max_options,
-    n_history_actions=args.n_history_actions,
-    play_mode='self',
-)
-envs.num_envs = args.local_num_envs
-obs_space = envs.observation_space
-action_shape = envs.action_space.shape
+envs = make_env(args, args.local_num_envs, local_env_threads)
+obs_space = envs.env.observation_space
+action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
-eval_envs = ygoenv.make(
-    task_id=args.env_id,
-    env_type="gymnasium",
-    num_envs=local_eval_num_envs,
-    num_threads=max(1, local_eval_num_envs // envs_per_thread),
-    seed=args.seed,
-    deck1=args.deck1,
-    deck2=args.deck2,
-    max_options=args.max_options,
-    n_history_actions=args.n_history_actions,
-    play_mode=args.play_mode,
-)
-eval_envs.num_envs = local_eval_num_envs
-envs = RecordEpisodeStatistics(envs)
-eval_envs = RecordEpisodeStatistics(eval_envs)
+local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
+eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
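Because make_env now returns envs already wrapped in RecordEpisodeStatistics, the raw environment's spaces are reached through the wrapper's .env attribute, which is what the switch to envs.env.observation_space above reflects. A hypothetical usage sketch (the env and thread counts are placeholders, not values from this commit):

# hypothetical values; in the script these come from args and the runtime setup
train_envs = make_env(args, 64, 8)               # self-play collection envs
eval_envs = make_env(args, 128, 16, mode='bot')  # evaluation against the greedy bot
obs_space = train_envs.env.observation_space     # raw spaces live on the wrapped ygoenv object
action_shape = train_envs.env.action_space.shape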
@@ -312,10 +297,10 @@ def main():
next_value1 = next_value2 = 0
step = 0
-for iteration in range(1, args.num_iterations + 1):
+for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
-frac = 1.0 - (iteration - 1.0) / args.num_iterations
+frac = 1.0 - iteration / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
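Switching to a zero-based iteration count leaves the annealing schedule effectively unchanged: the first update still runs at the full learning rate and the last at learning_rate / num_iterations. A quick check with placeholder values:

learning_rate, num_iterations = 2.5e-4, 1000  # placeholder values, not from this commit
first = (1.0 - 0 / num_iterations) * learning_rate                    # 2.5e-4 at iteration 0
last = (1.0 - (num_iterations - 1) / num_iterations) * learning_rate  # 2.5e-7 at the final iteration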
@@ -372,7 +357,7 @@ def main():
if random.random() < args.log_p:
n = 100
-if random.random() < 10/n or iteration <= 2:
+if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
@@ -394,7 +379,7 @@ def main():
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
-if step > 0 and iteration != 1:
+if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
@@ -421,8 +406,11 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
-b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
+if args.learn_opponent:
+    b_learns = torch.ones_like(b_values, dtype=torch.bool)
+else:
+    b_learns = learns[:args.num_steps].reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
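With learn_opponent now defaulting to False, b_learns flags only the transitions gathered by the learning side. The optimization loop itself is collapsed in this diff, so purely as an illustration (not the script's actual loss code), a boolean mask like b_learns is typically applied by averaging per-sample losses over the flagged entries:

import torch

def masked_mean(per_sample_loss, learn_mask):
    # average only over the samples the agent should learn from
    learn_mask = learn_mask.float()
    return (per_sample_loss * learn_mask).sum() / learn_mask.sum().clamp(min=1)

# e.g. pg_loss = masked_mean(per_sample_pg_loss, b_learns[mb_inds])  # mb_inds: hypothetical minibatch indices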
@@ -444,9 +432,6 @@ def main():
scaler.update()
clipfracs.append(clipfrac.item())
-if args.target_kl is not None and approx_kl > args.target_kl:
-    break
if step > 0:
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
@@ -463,7 +448,6 @@ def main():
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
@@ -490,7 +474,7 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
-if iteration % args.eval_interval == 0:
+if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
......