Commit ad9980e5 authored by biluo.shen

Add battle

parent e5e5402a
edopro-core @ 8c623744
Subproject commit 8c6237444e294b730bce1eccc6fab2721b7cbea9
@@ -65,7 +65,7 @@ class Args:
     checkpoint2: Optional[str] = "checkpoints/agent.pt"
     """the checkpoint to load for the second agent"""
-    compile: bool = True
+    compile: bool = False
     """if toggled, the model will be compiled"""
     optimize: bool = False
     """if toggled, the model will be optimized"""
@@ -130,33 +130,37 @@ if __name__ == "__main__":
     envs.num_envs = num_envs
     envs = RecordEpisodeStatistics(envs)

-    embedding_shape = args.num_embeddings
-    if embedding_shape is None:
-        with open(args.code_list_file, "r") as f:
-            code_list = f.readlines()
-            embedding_shape = len(code_list)
-    L = args.num_layers
-    agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-    agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
-    for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
-        state_dict = torch.load(ckpt, map_location=device)
-        if not args.compile:
-            prefix = "_orig_mod."
-            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-        print(agent.load_state_dict(state_dict))
-    if args.compile:
-        predict_step = torch.compile(predict_step, mode='reduce-overhead')
-    else:
-        if args.optimize:
-            obs = create_obs(envs.observation_space, (num_envs,), device=device)
-            def optimize_for_inference(agent):
-                with torch.no_grad():
-                    traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
-                    return torch.jit.optimize_for_inference(traced_model)
-            agent1 = optimize_for_inference(agent1)
-            agent2 = optimize_for_inference(agent2)
+    if args.checkpoint1.endswith(".ptj"):
+        agent1 = torch.jit.load(args.checkpoint1)
+        agent2 = torch.jit.load(args.checkpoint2)
+    else:
+        embedding_shape = args.num_embeddings
+        if embedding_shape is None:
+            with open(args.code_list_file, "r") as f:
+                code_list = f.readlines()
+                embedding_shape = len(code_list)
+        L = args.num_layers
+        agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+        agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
+        for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
+            state_dict = torch.load(ckpt, map_location=device)
+            if not args.compile:
+                prefix = "_orig_mod."
+                state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
+            print(agent.load_state_dict(state_dict))
+        if args.compile:
+            predict_step = torch.compile(predict_step, mode='reduce-overhead')
+        else:
+            if args.optimize:
+                obs = create_obs(envs.observation_space, (num_envs,), device=device)
+                def optimize_for_inference(agent):
+                    with torch.no_grad():
+                        traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
+                        return torch.jit.optimize_for_inference(traced_model)
+                agent1 = optimize_for_inference(agent1)
+                agent2 = optimize_for_inference(agent2)

     obs, infos = envs.reset()
     next_to_play_ = infos['to_play']
...
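Note on the hunk above: the two-agent battle script now accepts TorchScript checkpoints directly. A path ending in ".ptj" is loaded with torch.jit.load, while the else branch keeps the old route of building Agent and loading a raw state_dict, stripping the "_orig_mod." prefix that torch.compile adds to parameter names. A minimal sketch (toy module, PyTorch 2.x assumed, not the repo's Agent) of where that prefix comes from:

    import torch
    import torch.nn as nn

    # Stand-in for the real Agent: torch.compile wraps the module in an
    # OptimizedModule whose state_dict keys gain the "_orig_mod." prefix.
    net = nn.Linear(4, 2)
    compiled = torch.compile(net)
    print(list(compiled.state_dict().keys()))  # expected: ['_orig_mod.weight', '_orig_mod.bias']

The key-renaming dict comprehension in the hunk simply undoes that prefix so the checkpoint matches a plain, uncompiled Agent instance.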
@@ -80,6 +80,9 @@ class Args:
     """if toggled, the model will be compiled"""
     optimize: bool = True
     """if toggled, the model will be optimized"""
+    convert: bool = False
+    """if toggled, the model will be converted to a jit model and the program will exit"""
     torch_threads: Optional[int] = None
     """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
     env_threads: Optional[int] = 16
@@ -156,6 +159,21 @@ if __name__ == "__main__":
    print(agent.load_state_dict(state_dict))

    if args.compile:
+        if args.convert:
+            # Don't support dynamic shapes and very slow inference
+            raise NotImplementedError
+            # obs = create_obs(envs.observation_space, (num_envs,), device=device)
+            # dynamic_shapes = {"x": {}}
+            # # batch_dim = torch.export.Dim("batch", min=1, max=64)
+            # batch_dim = None
+            # for k, v in obs.items():
+            #     dynamic_shapes["x"][k] = {0: batch_dim}
+            # program = torch.export.export(
+            #     agent, (obs,),
+            #     dynamic_shapes=dynamic_shapes,
+            # )
+            # torch.export.save(program, args.checkpoint + "2")
+            # exit(0)
        agent = torch.compile(agent, mode='reduce-overhead')
    elif args.optimize:
        obs = create_obs(envs.observation_space, (num_envs,), device=device)
@@ -164,6 +182,10 @@ if __name__ == "__main__":
                traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
                return torch.jit.optimize_for_inference(traced_model)
        agent = optimize_for_inference(agent)
+        if args.convert:
+            torch.jit.save(agent, args.checkpoint + "j")
+            print(f"Optimized model saved to {args.checkpoint}j")
+            exit(0)

    obs, infos = envs.reset()
    next_to_play = infos['to_play']
...
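The two hunks above add an args.convert mode to the single-agent script: with optimize enabled, the agent is traced with torch.jit.trace, passed through torch.jit.optimize_for_inference, and saved next to the original checkpoint with a trailing "j" (agent.pt becomes agent.ptj), which is exactly the ".ptj" format the battle script's new branch loads. On the compile path, conversion stays unimplemented; the commented-out torch.export attempt was shelved for lacking dynamic-shape support and being very slow at inference. A self-contained sketch of the same trace/optimize/save round trip on a stand-in module (assumption: the real script traces the PPO Agent on a batch of env observations instead):

    import torch
    import torch.nn as nn

    # Stand-in model and example input; not the repo's Agent or observation dict.
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2)).eval()
    example = torch.randn(4, 8)

    with torch.no_grad():
        traced = torch.jit.trace(model, (example,), check_trace=False)
        optimized = torch.jit.optimize_for_inference(traced)

    torch.jit.save(optimized, "model.ptj")  # the "<checkpoint>j" naming used above
    restored = torch.jit.load("model.ptj")  # what the ".ptj" branch does at battle time
    print(restored(example).shape)          # torch.Size([4, 2])

Keep in mind that tracing specializes on the example inputs, so data-dependent control flow is baked in at conversion time.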
@@ -425,7 +425,7 @@ def run(local_rank, world_size):
        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if local_rank == 0:
            if iteration % args.save_interval == 0 or iteration == args.num_iterations:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
...
@@ -21,7 +21,7 @@ from torch.cuda.amp import GradScaler, autocast
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
 from ygoai.rl.agent import PPOAgent as Agent
-from ygoai.rl.dist import reduce_gradidents, mp_start, setup, fprint
+from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
@@ -118,8 +118,6 @@ class Args:
     """the number of iterations to save the model"""
     log_p: float = 1.0
     """the probability of logging"""
-    port: int = 12356
-    """the port to use for distributed training"""
     eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
     eval_interval: int = 10
@@ -140,7 +138,12 @@ class Args:
     """the number of processes (computed in runtime)"""


-def run(local_rank, world_size):
+def main():
+    rank = int(os.environ.get("RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
+
     args = tyro.cli(Args)
     args.world_size = world_size
     args.local_num_envs = args.num_envs // args.world_size
@@ -158,12 +161,12 @@ def run(local_rank, world_size):
    torch.set_float32_matmul_precision('high')
    if args.world_size > 1:
-        setup(args.backend, local_rank, args.world_size, args.port)
+        torchrun_setup(args.backend, local_rank)

    timestamp = int(time.time())
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"

    writer = None
-    if local_rank == 0:
+    if rank == 0:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
        writer.add_text(
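The old setup(args.backend, local_rank, args.world_size, args.port) call is replaced by torchrun_setup(args.backend, local_rank): under torchrun the rendezvous address, port, rank and world size all arrive as environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE), which is also why the port field was dropped from Args and why main() now reads RANK/LOCAL_RANK/WORLD_SIZE from os.environ. The helper itself lives in ygoai.rl.dist and is not shown in this commit; a sketch of the pattern it presumably follows:

    import torch
    import torch.distributed as dist

    def torchrun_setup(backend: str, local_rank: int):
        # torchrun already exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE,
        # so the process group can be initialized straight from the environment.
        dist.init_process_group(backend=backend, init_method="env://")
        if backend == "nccl":
            torch.cuda.set_device(local_rank)  # pin this process to its local GPU

A script structured this way is launched via torchrun (e.g. torchrun --nproc_per_node=<gpus> <script>.py ...) instead of spawning its own workers, which matches the entry point changing from mp_start(run) to a plain main() at the bottom of this diff.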
@@ -177,10 +180,10 @@ def run(local_rank, world_size):

    # TRY NOT TO MODIFY: seeding
    # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
-    args.seed += local_rank
+    args.seed += rank
    random.seed(args.seed)
    np.random.seed(args.seed)
-    torch.manual_seed(args.seed - local_rank)
+    torch.manual_seed(args.seed - rank)
    if args.torch_deterministic:
        torch.backends.cudnn.deterministic = True
    else:
@@ -188,7 +191,7 @@ def run(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")

-    deck = init_ygopro("english", args.deck, args.code_list_file)
+    deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck
@@ -429,7 +432,8 @@ def run(local_rank, world_size):
                    writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)

        collect_time = time.time() - collect_start
-        fprint(f"[Rank {local_rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
+        if local_rank == 0:
+            fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")

        _start = time.time()
        # bootstrap value if not done
@@ -561,16 +565,17 @@ def run(local_rank, world_size):

        train_time = time.time() - _start
-        fprint(f"[Rank {local_rank}] train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
+        if local_rank == 0:
+            fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
-        if local_rank == 0:
+        if rank == 0:
            if iteration % args.save_interval == 0:
-                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -581,15 +586,17 @@ def run(local_rank, world_size):
            writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
            writer.add_scalar("losses/explained_variance", explained_var, global_step)

        SPS = int((global_step - warmup_steps) / (time.time() - start_time))

        # Warmup at first few iterations for accurate SPS measurement
        SPS_warmup_iters = 10
        if iteration == SPS_warmup_iters:
            start_time = time.time()
            warmup_steps = global_step
        if iteration > SPS_warmup_iters:
-            fprint(f"SPS: {SPS}")
-            writer.add_scalar("charts/SPS", SPS, global_step)
+            if local_rank == 0:
+                fprint(f"SPS: {SPS}")
+            if rank == 0:
+                writer.add_scalar("charts/SPS", SPS, global_step)

        if iteration % args.eval_interval == 0:
@@ -628,11 +635,12 @@ def run(local_rank, world_size):
            # sync the statistics
            if args.world_size > 1:
                dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
-            if local_rank == 0:
-                eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
+            eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
+            if rank == 0:
                writer.add_scalar("charts/eval_return", eval_return, global_step)
                writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
                writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)

+            if local_rank == 0:
                eval_time = time.time() - _start
                fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
@@ -641,10 +649,10 @@ def run(local_rank, world_size):
    if args.world_size > 1:
        dist.destroy_process_group()
    envs.close()
-    if local_rank == 0:
-        torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
+    if rank == 0:
+        torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
        writer.close()


if __name__ == "__main__":
-    mp_start(run)
+    main()
@@ -530,7 +530,7 @@ def run(local_rank, world_size):
        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if local_rank == 0:
            if iteration % args.save_interval == 0:
-                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pth"))
+                torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
@@ -564,7 +564,7 @@ def run(local_rank, world_size):
                agent2.load_state_dict(agent1.state_dict())
                version += 1
                if local_rank == 0:
-                    torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pth"))
+                    torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
                    print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
                    avg_win_rates.clear()
                    avg_ep_returns.clear()
@@ -614,7 +614,7 @@ def run(local_rank, world_size):
    if args.world_size > 1:
        dist.destroy_process_group()
    envs.close()
    if local_rank == 0:
-        torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pth"))
+        torch.save(agent1.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
        writer.close()
...