Commit 80707a8c authored by Biluo Shen

add OSFP

parent 11261948
 *.pt
+*.ptj
 *.pkl
 # Xmake cache
......
@@ -3,7 +3,7 @@ import random
 import time
 from collections import deque
 from dataclasses import dataclass
-from typing import Literal, Optional
+from typing import Optional
 import ygoenv
@@ -52,10 +52,8 @@ class Args:
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 16
+    n_history_actions: int = 32
     """the number of history actions to use"""
-    play_mode: str = "bot"
-    """the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
     num_layers: int = 2
     """the number of layers for the agent"""
@@ -74,9 +72,9 @@ class Args:
     """the number of steps to run in each environment per policy rollout"""
     anneal_lr: bool = True
     """Toggle learning rate annealing for policy and value networks"""
-    gamma: float = 0.997
+    gamma: float = 1.0
     """the discount factor gamma"""
-    gae_lambda: float = 0.95
+    gae_lambda: float = 0.98
     """the lambda for the general advantage estimation"""
     minibatch_size: int = 256
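
For reference, gamma and gae_lambda enter the advantage computation as in standard GAE; a minimal, self-contained sketch (not this repository's exact per-player implementation, tensor layout is illustrative):

    import torch

    def compute_gae(rewards, values, next_value, dones, gamma=1.0, gae_lambda=0.98):
        # rewards, values, dones: [num_steps, num_envs]; dones[t] == 1.0 means the
        # episode ended at step t, so no bootstrapping past it
        num_steps = rewards.shape[0]
        advantages = torch.zeros_like(rewards)
        lastgaelam = 0.0
        for t in reversed(range(num_steps)):
            nextvalues = next_value if t == num_steps - 1 else values[t + 1]
            nextnonterminal = 1.0 - dones[t]
            # one-step TD error, then exponentially weighted (lambda) accumulation
            delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
            lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
            advantages[t] = lastgaelam
        return advantages, advantages + values  # advantages and lambda-returns

With gamma raised to 1.0 the rewards are undiscounted, so gae_lambda alone controls the bias/variance trade-off of the advantage estimates.
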
@@ -85,7 +83,7 @@ class Args:
     """the K epochs to update the policy"""
     norm_adv: bool = True
     """Toggles advantages normalization"""
-    clip_coef: float = 0.1
+    clip_coef: float = 0.2
     """the surrogate clipping coefficient"""
     clip_vloss: bool = True
     """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
@@ -93,17 +91,13 @@ class Args:
     """coefficient of the entropy"""
     vf_coef: float = 0.5
     """coefficient of the value function"""
-    max_grad_norm: float = 0.5
+    max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""
-    target_kl: Optional[float] = None
-    """the target KL divergence threshold"""
-    learn_opponent: bool = True
+    learn_opponent: bool = False
     """if toggled, the samples from the opponent will be used to train the agent"""
     collect_length: Optional[int] = None
     """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
-    backend: Literal["gloo", "nccl", "mpi"] = "nccl"
-    """the backend for distributed training"""
     compile: Optional[str] = None
     """Compile mode of torch.compile, None for no compilation"""
     torch_threads: Optional[int] = None
@@ -125,7 +119,7 @@ class Args:
     """the probability of logging"""
     eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
-    eval_interval: int = 10
+    eval_interval: int = 50
     """the number of iterations to evaluate the model"""
     # to be filled in runtime
@@ -143,6 +137,23 @@ class Args:
     """the number of processes (computed in runtime)"""

+def make_env(args, num_envs, num_threads, mode='self'):
+    envs = ygoenv.make(
+        task_id=args.env_id,
+        env_type="gymnasium",
+        num_envs=num_envs,
+        num_threads=num_threads,
+        seed=args.seed,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        max_options=args.max_options,
+        n_history_actions=args.n_history_actions,
+        play_mode=mode,
+    )
+    envs.num_envs = num_envs
+    envs = RecordEpisodeStatistics(envs)
+    return envs
+
 def main():
     rank = int(os.environ.get("RANK", 0))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
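
A quick usage sketch of the new helper (the numeric arguments are illustrative): training environments default to self-play, evaluation environments play the built-in greedy bot, and because the helper returns a RecordEpisodeStatistics wrapper the raw spaces are reached through `.env`:

    envs = make_env(args, 64, 8)                   # self-play rollout envs
    eval_envs = make_env(args, 16, 2, mode='bot')  # eval vs. rule-based bot
    obs_space = envs.env.observation_space         # unwrap to the underlying ygoenv spaces
    action_shape = envs.env.action_space.shape
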
@@ -169,7 +180,7 @@ def main():
     torch.set_float32_matmul_precision('high')
     if args.world_size > 1:
-        torchrun_setup(args.backend, local_rank)
+        torchrun_setup('nccl', local_rank)
     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
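
With the `backend` option removed, NCCL is now hardcoded for multi-GPU runs. `torchrun_setup` is defined elsewhere in the repository; an equivalent would look roughly like this (a sketch, assuming the standard environment variables exported by torchrun):

    import torch
    import torch.distributed as dist

    def torchrun_setup(backend: str, local_rank: int) -> None:
        # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
        # so init_process_group can read them from the environment
        dist.init_process_group(backend=backend)
        torch.cuda.set_device(local_rank)
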
@@ -204,43 +215,17 @@ def main():
     args.deck2 = args.deck2 or deck

     # env setup
-    envs = ygoenv.make(
-        task_id=args.env_id,
-        env_type="gymnasium",
-        num_envs=args.local_num_envs,
-        num_threads=local_env_threads,
-        seed=args.seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
-        max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
-        play_mode='self',
-    )
-    envs.num_envs = args.local_num_envs
-    obs_space = envs.observation_space
-    action_shape = envs.action_space.shape
+    envs = make_env(args, args.local_num_envs, local_env_threads)
+    obs_space = envs.env.observation_space
+    action_shape = envs.env.action_space.shape
     if local_rank == 0:
         fprint(f"obs_space={obs_space}, action_shape={action_shape}")

     envs_per_thread = args.local_num_envs // local_env_threads
     local_eval_episodes = args.eval_episodes // args.world_size
     local_eval_num_envs = local_eval_episodes
-    eval_envs = ygoenv.make(
-        task_id=args.env_id,
-        env_type="gymnasium",
-        num_envs=local_eval_num_envs,
-        num_threads=max(1, local_eval_num_envs // envs_per_thread),
-        seed=args.seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
-        max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
-        play_mode=args.play_mode,
-    )
-    eval_envs.num_envs = local_eval_num_envs
-
-    envs = RecordEpisodeStatistics(envs)
-    eval_envs = RecordEpisodeStatistics(eval_envs)
+    local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
+    eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')

     if args.embedding_file:
         embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -312,10 +297,10 @@ def main():
     next_value1 = next_value2 = 0
     step = 0
-    for iteration in range(1, args.num_iterations + 1):
+    for iteration in range(args.num_iterations):
         # Annealing the rate if instructed to do so.
         if args.anneal_lr:
-            frac = 1.0 - (iteration - 1.0) / args.num_iterations
+            frac = 1.0 - iteration / args.num_iterations
             lrnow = frac * args.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow
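
The iteration counter is now 0-based, so the linear schedule still starts at the full learning rate on the first update and decays toward zero on the last; a tiny sketch (the base rate shown is illustrative):

    def linear_lr(iteration, num_iterations, base_lr=2.5e-4):
        # iteration runs from 0 to num_iterations - 1
        return (1.0 - iteration / num_iterations) * base_lr

    # linear_lr(0, 1000) == 2.5e-4; linear_lr(999, 1000) == 2.5e-7
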
@@ -372,7 +357,7 @@ def main():
                     if random.random() < args.log_p:
                         n = 100
-                        if random.random() < 10/n or iteration <= 2:
+                        if random.random() < 10/n or iteration <= 1:
                             writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                             writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
                             fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
@@ -394,7 +379,7 @@ def main():
             nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
             nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
-            if step > 0 and iteration != 1:
+            if step > 0 and iteration != 0:
                 # recalculate the values for the first few steps
                 v_steps = args.local_minibatch_size * 4 // args.local_num_envs
                 for v_start in range(0, step, v_steps):
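
When collect_length > num_steps, the first `step` transitions in the buffer were collected under the previous policy, so their stored value estimates are stale after the update and are recomputed in chunks before GAE. A sketch of the idea; `agent.get_value` and the exact reshaping are assumptions for illustration, not the repository's exact code:

    with torch.no_grad():
        v_steps = args.local_minibatch_size * 4 // args.local_num_envs
        for v_start in range(0, step, v_steps):
            v_end = min(v_start + v_steps, step)
            # flatten the [chunk, num_envs, ...] observations and re-run only the value head
            v_obs = {k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()}
            values[v_start:v_end] = agent.get_value(v_obs).view(v_end - v_start, -1)
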
@@ -421,8 +406,11 @@ def main():
         b_logprobs = logprobs[:args.num_steps].reshape(-1)
         b_advantages = advantages[:args.num_steps].reshape(-1)
         b_values = values[:args.num_steps].reshape(-1)
-        b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
         b_returns = b_advantages + b_values
+        if args.learn_opponent:
+            b_learns = torch.ones_like(b_values, dtype=torch.bool)
+        else:
+            b_learns = learns[:args.num_steps].reshape(-1)

         # Optimizing the policy and value network
         b_inds = np.arange(args.local_batch_size)
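
With learn_opponent now defaulting to False, `b_learns` masks out the opponent's transitions so that only the agent's own decisions contribute to the loss. A minimal sketch of how such a mask is typically applied (the helper and index names are illustrative):

    import torch

    def masked_mean(per_sample_loss: torch.Tensor, learns: torch.Tensor) -> torch.Tensor:
        # average only over the transitions the agent should learn from
        m = learns.float()
        return (per_sample_loss * m).sum() / m.sum().clamp(min=1.0)

    # e.g. pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), b_learns[mb_inds])
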
@@ -444,9 +432,6 @@ def main():
                 scaler.update()
                 clipfracs.append(clipfrac.item())
-            if args.target_kl is not None and approx_kl > args.target_kl:
-                break

         if step > 0:
             # TODO: use cyclic buffer to avoid copying
             for v in obs.values():
@@ -463,7 +448,6 @@ def main():
         var_y = np.var(y_true)
         explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

-        # TRY NOT TO MODIFY: record rewards for plotting purposes
         if rank == 0:
             if iteration % args.save_interval == 0:
                 torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
@@ -490,7 +474,7 @@ def main():
         if rank == 0:
             writer.add_scalar("charts/SPS", SPS, global_step)
-        if iteration % args.eval_interval == 0:
+        if args.eval_interval and iteration % args.eval_interval == 0:
             # Eval with rule-based policy
             _start = time.time()
             eval_return = evaluate(
......
This diff is collapsed.