Commit bbd36d86 authored by sbl1996@126.com

replace torch with jax as default

parent 43ca871e
@@ -4,16 +4,20 @@ import os
 import random
 from typing import Optional, Literal
 from dataclasses import dataclass
+from tqdm import tqdm
 
 import ygoenv
 import numpy as np
-import optree
 import tyro
+import jax
+import jax.numpy as jnp
+import flax
 
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
+from ygoai.rl.jax.agent2 import PPOLSTMAgent
 
 
 @dataclass
@@ -54,6 +58,8 @@ class Args:
     """the number of layers for the agent"""
     num_channels: int = 128
     """the number of channels for the agent"""
+    rnn_channels: Optional[int] = 512
+    """the number of rnn channels for the agent"""
     checkpoint1: str = "checkpoints/agent.pt"
     """the checkpoint to load for the first agent, `pt` or `flax_model` file"""
     checkpoint2: str = "checkpoints/agent.pt"
@@ -63,25 +69,30 @@ class Args:
     xla_device: Optional[str] = None
     """the XLA device to use, defaults to `None`"""
 
-    # PyTorch specific
-    torch_deterministic: bool = True
-    """if toggled, `torch.backends.cudnn.deterministic=False`"""
-    cuda: bool = True
-    """if toggled, cuda will be enabled by default"""
-    compile: bool = False
-    """if toggled, the model will be compiled"""
-    optimize: bool = False
-    """if toggled, the model will be optimized"""
-    torch_threads: Optional[int] = None
-    """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
-
     env_threads: Optional[int] = 16
     """the number of threads to use for envpool, defaults to `num_envs`"""
-    framework: Optional[Literal["torch", "jax"]] = None
+
+
+def create_agent(args):
+    return PPOLSTMAgent(
+        channels=args.num_channels,
+        num_layers=args.num_layers,
+        lstm_channels=args.rnn_channels,
+        embedding_shape=args.num_embeddings,
+    )
+
+
+def init_rnn_state(num_envs, rnn_channels):
+    return (
+        np.zeros((num_envs, rnn_channels)),
+        np.zeros((num_envs, rnn_channels)),
+    )
if __name__ == "__main__": if __name__ == "__main__":
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args = tyro.cli(Args) args = tyro.cli(Args)
if args.record: if args.record:
@@ -101,18 +112,6 @@ if __name__ == "__main__":
     random.seed(seed)
     np.random.seed(seed)
 
-    if args.framework is None:
-        args.framework = "jax" if "flax_model" in args.checkpoint1 else "torch"
-
-    if args.framework == "torch":
-        import torch
-        torch.manual_seed(args.seed)
-        torch.backends.cudnn.deterministic = args.torch_deterministic
-        args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
-        torch.set_num_threads(args.torch_threads)
-        torch.set_float32_matmul_precision('high')
-    else:
-        if args.xla_device is not None:
-            os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
+    if args.xla_device is not None:
+        os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
@@ -138,78 +137,14 @@ if __name__ == "__main__":
     envs.num_envs = num_envs
     envs = RecordEpisodeStatistics(envs)
 
-    if args.framework == 'torch':
-        from ygoai.rl.agent import PPOAgent as Agent
-        from ygoai.rl.buffer import create_obs
-        device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
-
-        if args.checkpoint1.endswith(".ptj"):
-            agent1 = torch.jit.load(args.checkpoint1)
-            agent2 = torch.jit.load(args.checkpoint2)
-        else:
-            # count lines of code_list
-            embedding_shape = args.num_embeddings
-            if embedding_shape is None:
-                with open(args.code_list_file, "r") as f:
-                    code_list = f.readlines()
-                    embedding_shape = len(code_list)
-            L = args.num_layers
-            agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
-            agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
-            for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
-                state_dict = torch.load(ckpt, map_location=device)
-                if not args.compile:
-                    prefix = "_orig_mod."
-                    state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-                print(agent.load_state_dict(state_dict))
-
-        def get_probs(agent, obs):
-            with torch.no_grad():
-                return torch.softmax(agent(obs)[0], dim=-1)
-
-        if args.compile:
-            get_probs = torch.compile(get_probs, mode='reduce-overhead')
-        elif args.optimize:
-            obs = create_obs(envs.observation_space, (num_envs,), device=device)
-            def optimize_for_inference(agent):
-                with torch.no_grad():
-                    traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
-                    return torch.jit.optimize_for_inference(traced_model)
-            agent1 = optimize_for_inference(agent1)
-            agent2 = optimize_for_inference(agent2)
-
-        def predict_fn(obs, main):
-            obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
-            if num_envs != 1:
-                probs1 = get_probs(agent, obs)
-                probs2 = get_probs(agent, obs)
-                probs = torch.where(main[:, None], probs1, probs2)
-            else:
-                if main[0]:
-                    probs = get_probs(agent1, obs)
-                else:
-                    probs = get_probs(agent2, obs)
-            probs = probs.cpu().numpy()
-            return probs
-    else:
-        import jax
-        import jax.numpy as jnp
-        import flax
-        from ygoai.rl.jax.agent2 import PPOAgent
-
-        def create_agent(args):
-            return PPOAgent(
-                channels=128,
-                num_layers=2,
-                embedding_shape=args.num_embeddings,
-            )
-
-        agent = create_agent(args)
-        key = jax.random.PRNGKey(args.seed)
-        key, agent_key = jax.random.split(key, 2)
-        sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
-        params = agent.init(agent_key, sample_obs)
-        print(jax.tree.leaves(params)[0].devices())
+    agent = create_agent(args)
+    key = jax.random.PRNGKey(args.seed)
+    key, agent_key = jax.random.split(key, 2)
+    sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
+    rstate = init_rnn_state(1, args.rnn_channels)
+    params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
 
     with open(args.checkpoint1, "rb") as f:
         params1 = flax.serialization.from_bytes(params, f.read())
     if args.checkpoint1 == args.checkpoint2:
@@ -218,38 +153,51 @@ if __name__ == "__main__":
         with open(args.checkpoint2, "rb") as f:
             params2 = flax.serialization.from_bytes(params, f.read())
+    params1 = jax.device_put(params1)
+    params2 = jax.device_put(params2)
 
-        @jax.jit
-        def get_probs(params, obs):
-            agent = create_agent(args)
-            logits = agent.apply(params, obs)[0]
-            return jax.nn.softmax(logits, axis=-1)
+    @jax.jit
+    def get_probs(params, rstate, obs, done):
+        agent = create_agent(args)
+        next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
+        probs = jax.nn.softmax(logits, axis=-1)
+        next_rstate = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), next_rstate)
+        return next_rstate, probs
 
-        if args.num_envs != 1:
-            @jax.jit
-            def get_probs2(params1, params2, obs, main):
-                probs1 = get_probs(params1, obs)
-                probs2 = get_probs(params2, obs)
-                probs = jnp.where(main[:, None], probs1, probs2)
-                return probs
-
-            def predict_fn(obs, main):
-                probs = get_probs2(params1, params2, obs, main)
-                return np.array(probs)
-        else:
-            def predict_fn(obs, main):
-                if main[0]:
-                    probs = get_probs(params1, obs)
-                else:
-                    probs = get_probs(params2, obs)
-                return np.array(probs)
+    if args.num_envs != 1:
+        @jax.jit
+        def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
+            next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
+            next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
+            probs = jnp.where(main[:, None], probs1, probs2)
+            rstate1 = jax.tree.map(
+                lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
+            rstate2 = jax.tree.map(
+                lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+            return rstate1, rstate2, probs
+
+        def predict_fn(rstate1, rstate2, obs, main, done):
+            rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done)
+            return rstate1, rstate2, np.array(probs)
+    else:
+        def predict_fn(rstate1, rstate2, obs, main, done):
+            if main[0]:
+                rstate1, probs = get_probs(params1, rstate1, obs, done)
+            else:
+                rstate2, probs = get_probs(params2, rstate2, obs, done)
+            return rstate1, rstate2, np.array(probs)
 
     obs, infos = envs.reset()
     next_to_play = infos['to_play']
+    dones = np.zeros(num_envs, dtype=np.bool_)
 
     episode_rewards = []
    episode_lengths = []
     win_rates = []
     win_reasons = []
+    win_players = []
+    win_agents = []
 
     step = 0
     start = time.time()
@@ -259,6 +207,10 @@ if __name__ == "__main__":
         np.zeros(num_envs // 2, dtype=np.int64),
         np.ones(num_envs - num_envs // 2, dtype=np.int64)
     ])
+    rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
+
+    if not args.verbose:
+        pbar = tqdm(total=args.num_episodes)
 
     model_time = env_time = 0
     while True:
@@ -268,9 +220,8 @@ if __name__ == "__main__":
             model_time = env_time = 0
 
         _start = time.time()
         main = next_to_play == main_player
-        probs = predict_fn(obs, main)
+        rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones)
         actions = probs.argmax(axis=1)
         model_time += time.time() - _start
@@ -289,14 +240,24 @@ if __name__ == "__main__":
                 win_reason = infos['win_reason'][idx]
                 pl = 1 if to_play[idx] == main_player[idx] else -1
                 episode_length = infos['l'][idx]
-                episode_reward = infos['r'][idx] * pl
-                win = int(episode_reward > 0)
+                episode_reward = infos['r'][idx]
+                main_reward = episode_reward * pl
+                win = int(main_reward > 0)
+                win_player = 0 if (to_play[idx] == 0 and episode_reward > 0) or (to_play[idx] == 1 and episode_reward < 0) else 1
+                win_players.append(win_player)
+                win_agent = 1 if main_reward > 0 else 2
+                win_agents.append(win_agent)
 
                 episode_lengths.append(episode_length)
-                episode_rewards.append(episode_reward)
+                episode_rewards.append(main_reward)
                 win_rates.append(win)
                 win_reasons.append(1 if win_reason == 1 else 0)
-                sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+                if args.verbose:
+                    print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={main_reward}, win={win}, win_reason={win_reason}\n")
+                else:
+                    pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
+                    pbar.update(1)
 
                 # Only when num_envs=1, we switch the player here
                 if args.verbose:
@@ -305,10 +266,38 @@ if __name__ == "__main__":
         if len(episode_lengths) >= args.num_episodes:
             break
 
+    if not args.verbose:
+        pbar.close()
+
     print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
 
+    win_players = np.array(win_players)
+    win_agents = np.array(win_agents)
+    N = len(win_players)
+    N1 = np.sum((win_players == 0) & (win_agents == 1))
+    N2 = np.sum((win_players == 0) & (win_agents == 2))
+    N3 = np.sum((win_players == 1) & (win_agents == 1))
+    N4 = np.sum((win_players == 1) & (win_agents == 2))
+
+    print(f"Payoff matrix:")
+    w1 = N1 / N
+    w2 = N2 / N
+    w3 = N3 / N
+    w4 = N4 / N
+    print(f" agent1 agent2")
+    print(f"0 {w1:.4f} {w2:.4f}")
+    print(f"1 {w3:.4f} {w4:.4f}")
+
+    print(f"0/1 matrix, win rates of agentX as playerY")
+    w1 = N1 / (N1 + N4)
+    w2 = N2 / (N2 + N3)
+    w3 = 1 - w2
+    w4 = 1 - w1
+    print(f" agent1 agent2")
+    print(f"0 {w1:.4f} {w2:.4f}")
+    print(f"1 {w3:.4f} {w4:.4f}")
+
     total_time = time.time() - start
     total_steps = (step - start_step) * num_envs
     print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
     print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
@@ -8,8 +8,6 @@ from dataclasses import dataclass
 
 import ygoenv
 import numpy as np
-import optree
 import tyro
 
 from ygoai.utils import init_ygopro
@@ -63,6 +61,8 @@ class Args:
     """the number of layers for the agent"""
     num_channels: int = 128
     """the number of channels for the agent"""
+    rnn_channels: Optional[int] = 512
+    """the number of rnn channels for the agent"""
     checkpoint: Optional[str] = None
    """the checkpoint to load, `pt` or `flax_model` file"""
@@ -70,24 +70,24 @@ class Args:
     xla_device: Optional[str] = None
     """the XLA device to use, defaults to `None`"""
 
-    # PyTorch specific
-    torch_deterministic: bool = True
-    """if toggled, `torch.backends.cudnn.deterministic=False`"""
-    cuda: bool = True
-    """if toggled, cuda will be enabled by default"""
-    compile: bool = False
-    """if toggled, the model will be compiled"""
-    optimize: bool = True
-    """if toggled, the model will be optimized"""
-    convert: bool = False
-    """if toggled, the model will be converted to a jit model and the program will exit"""
-    torch_threads: Optional[int] = None
-    """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
-
     env_threads: Optional[int] = 16
     """the number of threads to use for envpool, defaults to `num_envs`"""
-    framework: Optional[Literal["torch", "jax"]] = None
+
+
+def create_agent(args):
+    return PPOLSTMAgent(
+        channels=args.num_channels,
+        num_layers=args.num_layers,
+        lstm_channels=args.rnn_channels,
+        embedding_shape=args.num_embeddings,
+    )
+
+
+def init_rnn_state(num_envs, rnn_channels):
+    return (
+        np.zeros((num_envs, rnn_channels)),
+        np.zeros((num_envs, rnn_channels)),
+    )
 
 
 if __name__ == "__main__":
@@ -113,18 +113,6 @@ if __name__ == "__main__":
     random.seed(seed)
     np.random.seed(seed)
 
-    if args.checkpoint and args.framework is None:
-        args.framework = "jax" if "flax_model" in args.checkpoint else "torch"
-
-    if args.framework == "torch":
-        import torch
-        torch.manual_seed(args.seed)
-        torch.backends.cudnn.deterministic = args.torch_deterministic
-        args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
-        torch.set_num_threads(args.torch_threads)
-        torch.set_float32_matmul_precision('high')
-    else:
-        if args.xla_device is not None:
-            os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
+    if args.xla_device is not None:
+        os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
@@ -150,102 +138,46 @@ if __name__ == "__main__":
     envs.num_envs = num_envs
     envs = RecordEpisodeStatistics(envs)
 
-    if args.framework == 'torch':
-        from ygoai.rl.agent import PPOAgent as Agent
-        from ygoai.rl.buffer import create_obs
-        device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
-
-        if args.checkpoint.endswith(".ptj"):
-            agent = torch.jit.load(args.checkpoint)
-        else:
-            # count lines of code_list
-            embedding_shape = args.num_embeddings
-            if embedding_shape is None:
-                with open(args.code_list_file, "r") as f:
-                    code_list = f.readlines()
-                    embedding_shape = len(code_list)
-            L = args.num_layers
-            agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
-            state_dict = torch.load(args.checkpoint, map_location=device)
-            if not args.compile:
-                prefix = "_orig_mod."
-                state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-            print(agent.load_state_dict(state_dict))
-
-        if args.compile:
-            if args.convert:
-                # Don't support dynamic shapes and very slow inference
-                raise NotImplementedError
-                # obs = create_obs(envs.observation_space, (num_envs,), device=device)
-                # dynamic_shapes = {"x": {}}
-                # # batch_dim = torch.export.Dim("batch", min=1, max=64)
-                # batch_dim = None
-                # for k, v in obs.items():
-                #     dynamic_shapes["x"][k] = {0: batch_dim}
-                # program = torch.export.export(
-                #     agent, (obs,),
-                #     dynamic_shapes=dynamic_shapes,
-                # )
-                # torch.export.save(program, args.checkpoint + "2")
-                # exit(0)
-            agent = torch.compile(agent, mode='reduce-overhead')
-        elif args.optimize:
-            obs = create_obs(envs.observation_space, (num_envs,), device=device)
-            def optimize_for_inference(agent):
-                with torch.no_grad():
-                    traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
-                    return torch.jit.optimize_for_inference(traced_model)
-            agent = optimize_for_inference(agent)
-            if args.convert:
-                torch.jit.save(agent, args.checkpoint + "j")
-                print(f"Optimized model saved to {args.checkpoint}j")
-                exit(0)
-
-        def predict_fn(obs):
-            obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
-            with torch.no_grad():
-                logits = agent(obs)[0]
-            probs = torch.softmax(logits, dim=-1)
-            probs = probs.cpu().numpy()
-            return probs
-    else:
+    if args.checkpoint:
         import jax
         import jax.numpy as jnp
         import flax
-        from ygoai.rl.jax.agent2 import PPOAgent
+        from ygoai.rl.jax.agent2 import PPOLSTMAgent
+        from jax.experimental.compilation_cache import compilation_cache as cc
+        cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
 
-        def create_agent(args):
-            return PPOAgent(
-                channels=128,
-                num_layers=2,
-                embedding_shape=args.num_embeddings,
-            )
-
         agent = create_agent(args)
         key = jax.random.PRNGKey(args.seed)
         key, agent_key = jax.random.split(key, 2)
         sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
-        params = agent.init(agent_key, sample_obs)
+        rstate = init_rnn_state(1, args.rnn_channels)
+        params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
 
         with open(args.checkpoint, "rb") as f:
             params = flax.serialization.from_bytes(params, f.read())
+        params = jax.device_put(params)
 
         @jax.jit
-        def get_probs(
-            params: flax.core.FrozenDict,
-            next_obs,
-        ):
-            logits = create_agent(args).apply(params, next_obs)[0]
-            return jax.nn.softmax(logits, axis=-1)
+        def get_probs(params, rstate, obs, done):
+            agent = create_agent(args)
+            next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
+            probs = jax.nn.softmax(logits, axis=-1)
+            next_rstate = jax.tree.map(
+                lambda x: jnp.where(done[:, None], 0, x), next_rstate)
+            return next_rstate, probs
 
-        def predict_fn(obs):
-            probs = get_probs(params, obs)
-            return np.array(probs)
+        def predict_fn(rstate, obs, done):
+            rstate, probs = get_probs(params, rstate, obs, done)
+            return rstate, np.array(probs)
 
         print(f"loaded checkpoint from {args.checkpoint}")
 
     obs, infos = envs.reset()
     next_to_play = infos['to_play']
+    dones = np.zeros(num_envs, dtype=np.bool_)
 
     episode_rewards = []
     episode_lengths = []
@@ -263,9 +195,9 @@ if __name__ == "__main__":
             start_step = step
             model_time = env_time = 0
 
-        if args.framework:
+        if args.checkpoint:
             _start = time.time()
-            probs = predict_fn(obs)
+            rstate, probs = predict_fn(rstate, obs, dones)
             if args.verbose:
                 print([f"{p:.4f}" for p in probs[probs != 0].tolist()])
             actions = probs.argmax(axis=1)
@@ -306,4 +238,3 @@ if __name__ == "__main__":
     total_steps = (step - start_step) * num_envs
     print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
     print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
from tqdm import tqdm
import ygoenv
import numpy as np
import tyro
import jax
import jax.numpy as jnp
import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import PPOLSTMAgent
@dataclass
class Args:
seed: int = 1
"""the random seed"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
record: bool = False
"""whether to record the game as YGOPro replays"""
num_episodes: int = 1024
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework: Optional[Literal["torch", "jax"]] = None
def create_agent(args):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
lstm_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
)
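# The recurrent state is a pair of zero-initialized (num_envs, rnn_channels)
# arrays holding the LSTM carry; fresh episodes start from all zeros.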
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
if __name__ == "__main__":
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
if not os.path.exists("replay"):
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
seed = args.seed
random.seed(seed)
np.random.seed(seed)
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=args.env_threads,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
player=-1,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read())
if args.checkpoint1 == args.checkpoint2:
params2 = params1
else:
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
params1 = jax.device_put(params1)
params2 = jax.device_put(params2)
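# One inference step: apply the LSTM agent, softmax the logits into action
# probabilities, and zero the recurrent state of any env whose episode ended.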
@jax.jit
def get_probs(params, rstate, obs, done):
agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs
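# Batched two-agent play: run both parameter sets on the whole batch, pick each
# env's probabilities from the agent that is to act, and advance an agent's
# recurrent state only on the envs where it acted.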
if args.num_envs != 1:
@jax.jit
def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
probs = jnp.where(main[:, None], probs1, probs2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, probs
def predict_fn(rstate1, rstate2, obs, main, done):
rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done)
return rstate1, rstate2, np.array(probs)
else:
def predict_fn(rstate1, rstate2, obs, main, done):
if main[0]:
rstate1, probs = get_probs(params1, rstate1, obs, done)
else:
rstate2, probs = get_probs(params2, rstate2, obs, done)
return rstate1, rstate2, np.array(probs)
obs, infos = envs.reset()
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
episode_rewards = []
episode_lengths = []
win_rates = []
win_reasons = []
step = 0
start = time.time()
start_step = step
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
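# agent1 takes the first seat in half of the envs and the second seat in the
# rest, so the reported win rate is averaged over both seats.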
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
if not args.verbose:
pbar = tqdm(total=args.num_episodes)
model_time = env_time = 0
while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
start = time.time()
start_step = step
model_time = env_time = 0
_start = time.time()
main = next_to_play == main_player
rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones)
actions = probs.argmax(axis=1)
model_time += time.time() - _start
to_play = next_to_play
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
next_to_play = infos['to_play']
env_time += time.time() - _start
step += 1
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
if args.verbose:
print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
else:
pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
pbar.update(1)
# Only when num_envs=1, we switch the player here
if args.verbose:
main_player = 1 - main_player
if len(episode_lengths) >= args.num_episodes:
break
if not args.verbose:
pbar.close()
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
total_time = time.time() - start
total_steps = (step - start_step) * num_envs
print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
from ygoai.rl.jax.switch import truncated_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = True
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = None
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy, typically 0.02"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = False
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
freeze_id: bool = False
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
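# One rollout step across all envs; `mains` records which player acted and
# `next_dones` marks episodes that terminated as a result of the step.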
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logits: list
rewards: list
mains: list
next_dones: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
freeze_id=args.freeze_id,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue,
eval_queue,
writer,
learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs, done):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
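# Greedy evaluation step for two parameter sets: compute both agents' logits,
# select per env by which agent is to act, and advance only the acting agent's
# recurrent state.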
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
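# Training-time sampling: pick the acting player's LSTM carry per env, step the
# agent once, write the updated carry back to that player only, then sample an
# action from the logits.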
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done)
main = jnp.array(main)
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs), done)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
storage = []
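# Stack the per-step transitions into (num_steps, num_envs, ...) arrays and
# split them along the env axis, one shard per learner device.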
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(args.num_steps):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, cached_next_done, cached_main, \
next_rstate1, next_rstate2, action, logits, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=cached_main,
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
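# When an episode ends, walk back to the other player's most recent step, mark
# it as terminal with the negated final reward, so both players receive an
# end-of-episode transition.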
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
partitioned_storage = prepare_data(storage)
storage = []
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_stat = np.array([eval_return, eval_win_rate])
if device_thread_id != 0:
eval_queue.put(eval_stat)
else:
eval_stats = []
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(eval_queue.get())
eval_stats = np.stack(eval_stats)
eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
other_time += eval_time
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
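# If pretrained card embeddings are given, overwrite the embedding table;
# row 0 holds the mean embedding, presumably standing in for unknown card ids.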
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
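# PPO loss over the reordered two-player sequences: recompute logits and values
# with the multi-step agent, form importance-sampling ratios against the
# behavior logits, and build GAE targets that account for player switches.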
def ppo_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch, actions, logits, rewards, mask, next_value):
# (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
new_values_, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_values, rewards, next_dones, switch),
)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
target_values, advantages = truncated_gae_2p0s(
next_value, new_values_, rewards, next_dones, switch,
args.gamma, args.gae_lambda, args.upgo)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
pg_loss = simple_policy_loss(
ratios, logits, new_logits, advantages, args.spo_kld_max)
else:
pg_loss = clipped_surrogate_pg_loss(
ratios, advantages, args.clip_coef, args.dual_clip_coef)
pg_loss = jnp.sum(pg_loss * mask)
v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask)
ent_loss = entropy_loss(new_logits)
ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
ent_loss = ent_loss / n_valids
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main, = [
jnp.concatenate(x) for x in [sharded_next_main]
]
# reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
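# After sorting, each env's steps are grouped with the main player's first;
# `switch` flags the boundary step where control passes to the opponent.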
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
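# Shuffle along the env axis between epochs (only when update_epochs > 1) and
# carve the batch into num_minibatches chunks, keeping the step axis contiguous
# for time-major tensors.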
def convert_data(x: jnp.ndarray, num_steps):
if args.update_epochs > 1:
x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
N = args.num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(7,),
)
params_queues = []
rollout_queues = []
eval_queue = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
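# Fail fast: a NaN/Inf loss means training has diverged, so abort
# instead of silently publishing broken params to the actors.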
# log training and rollout statistics
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
f"{global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
import os import os
import queue
import random import random
import threading
import time import time
from datetime import datetime, timedelta, timezone
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np import numpy as np
import optax
import distrax
import tyro import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
from ygoai.rl.jax.switch import truncated_gae_2p0s
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
from ygoai.rl.eval import evaluate
@dataclass @dataclass
class Args: class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")] exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment""" """the name of this experiment"""
seed: int = 1 seed: int = 1
"""seed of the experiment""" """seed of the experiment"""
torch_deterministic: bool = False log_frequency: int = 10
"""if toggled, `torch.backends.cudnn.deterministic=False`""" """the logging frequency of the model performance (in terms of `updates`)"""
cuda: bool = True save_interval: int = 400
"""if toggled, cuda will be enabled by default""" """the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments # Algorithm specific arguments
env_id: str = "YGOPro-v0" env_id: str = "YGOPro-v0"
...@@ -54,499 +63,806 @@ class Args: ...@@ -54,499 +63,806 @@ class Args:
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 32 n_history_actions: int = 32
"""the number of history actions to use""" """the number of history actions to use"""
greedy_reward: bool = True
"""whether to use greedy reward (faster kill higher reward)"""
num_layers: int = 2 total_timesteps: int = 5000000000
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments""" """total timesteps of the experiments"""
learning_rate: float = 2.5e-4 learning_rate: float = 1e-3
"""the learning rate of the optimizer""" """the learning rate of the optimizer"""
num_envs: int = 8 local_num_envs: int = 128
"""the number of parallel game environments""" """the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128 num_steps: int = 128
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0 gamma: float = 1.0
"""the discount factor gamma""" """the discount factor gamma"""
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
upgo: bool = False
fix_target: bool = False """Toggle the use of UPGO for advantages"""
"""if toggled, the target network will be fixed""" num_minibatches: int = 8
update_win_rate: float = 0.55 """the number of mini-batches"""
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2 update_epochs: int = 2
"""the K epochs to update the policy""" """the K epochs to update the policy"""
norm_adv: bool = True norm_adv: bool = False
"""Toggles advantages normalization""" """Toggles advantages normalization"""
clip_coef: float = 0.2 clip_coef: float = 0.25
"""the surrogate clipping coefficient""" """the surrogate clipping coefficient"""
clip_vloss: bool = True dual_clip_coef: Optional[float] = None
"""Toggles whether or not to use a clipped loss for the value function, as per the paper.""" """the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy, typically 0.02"""
ent_coef: float = 0.01 ent_coef: float = 0.01
"""coefficient of the entropy""" """coefficient of the entropy"""
vf_coef: float = 0.5 vf_coef: float = 0.5
"""coefficient of the value function""" """coefficient of the value function"""
max_grad_norm: float = 1.0 max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)""" num_layers: int = 2
"""the number of layers for the agent"""
compile: Optional[str] = None num_channels: int = 128
"""Compile mode of torch.compile, None for no compilation""" """the number of channels for the agent"""
torch_threads: Optional[int] = None rnn_channels: int = 512
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size""" """the number of channels for the RNN in the agent"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`""" actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
fp16_train: bool = False """the device ids that actor workers will use"""
"""if toggled, training will be done in fp16 precision""" learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
fp16_eval: bool = False """the device ids that learner workers will use"""
"""if toggled, evaluation will be done in fp16 precision""" distributed: bool = False
"""whether to use `jax.distirbuted`"""
tb_dir: str = "./runs" concurrency: bool = True
"""tensorboard log directory""" """whether to run the actor and learner concurrently"""
ckpt_dir: str = "./checkpoints" bfloat16: bool = False
"""checkpoint directory""" """whether to use bfloat16 for the agent"""
save_interval: int = 500 thread_affinity: bool = False
"""the number of iterations to save the model""" """whether to use thread affinity for the environment"""
log_p: float = 1.0
"""the probability of logging""" eval_checkpoint: Optional[str] = None
eval_episodes: int = 128 """the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model""" """the number of episodes to evaluate the model"""
eval_interval: int = 50 eval_interval: int = 50
"""the number of iterations to evaluate the model""" """the number of iterations to evaluate the model"""
# to be filled in runtime # runtime arguments to be filled in
local_batch_size: int = 0 local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0 local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0 world_size: int = 0
"""the number of processes (computed in runtime)""" local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_devices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None num_embeddings: Optional[int] = None
"""the number of embeddings (computed in runtime)""" freeze_id: bool = False
def make_env(args, num_envs, num_threads, mode='self'): def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make( envs = ygoenv.make(
task_id=args.env_id, task_id=args.env_id,
env_type="gymnasium", env_type="gymnasium",
num_envs=num_envs, num_envs=num_envs,
num_threads=num_threads, num_threads=num_threads,
seed=args.seed, thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1, deck1=args.deck1,
deck2=args.deck2, deck2=args.deck2,
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode, play_mode=mode,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs return envs
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads) class Transition(NamedTuple):
torch.set_float32_matmul_precision('high') obs: list
dones: list
if args.world_size > 1: actions: list
torchrun_setup('nccl', local_rank) logits: list
rewards: list
timestamp = int(time.time()) mains: list
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}" next_dones: list
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter def create_agent(args, multi_step=False):
writer = SummaryWriter(os.path.join(args.tb_dir, run_name)) return PPOLSTMAgent(
writer.add_text( channels=args.num_channels,
"hyperparameters", num_layers=args.num_layers,
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
freeze_id=args.freeze_id,
) )
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu") def init_rnn_state(num_envs, rnn_channels):
return (
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file) np.zeros((num_envs, rnn_channels)),
args.deck1 = args.deck1 or deck np.zeros((num_envs, rnn_channels)),
args.deck2 = args.deck2 or deck )
# env setup
envs = make_env(args, args.local_num_envs, local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
torch.manual_seed(args.seed)
agent.eval()
if args.checkpoint: def rollout(
agent.load_state_dict(torch.load(args.checkpoint, map_location=device)) key: jax.random.PRNGKey,
fprint(f"Loaded checkpoint from {args.checkpoint}") args: Args,
elif args.embedding_file: rollout_queue,
agent.load_embeddings(embeddings) params_queue,
fprint(f"Loaded embeddings from {args.embedding_file}") eval_queue,
if args.embedding_file: writer,
agent.freeze_embeddings() learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
if args.fix_target: eval_envs = make_env(
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device) args,
agent_t.eval() args.seed + jax.process_index() + device_thread_id,
agent_t.load_state_dict(agent.state_dict()) args.local_eval_episodes,
else: args.local_eval_episodes // 4, mode=eval_mode, eval=True)
agent_t = agent eval_envs = RecordEpisodeStatistics(eval_envs)
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
return logits, value
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here len_actor_device_ids = len(args.actor_device_ids)
# predict_step = torch.compile(predict_step, mode=args.compile)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
if args.fix_target:
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
else:
traced_model_t = traced_model
train_step = torch.compile(train_step_, mode=args.compile) len_actor_device_ids = len(args.actor_device_ids)
else: n_actors = args.num_actor_threads * len_actor_device_ids
traced_model = agent global_step = 0
traced_model_t = agent_t start_time = time.time()
train_step = train_step_ warmup_step = 0
other_time = 0
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000) avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000) avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game @jax.jit
global_step = 0 def get_logits(
warmup_steps = 0 params: flax.core.FrozenDict, inputs, done):
start_time = time.time() rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done)
main = jnp.array(main)
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs), done)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
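# sample_action keeps one LSTM state per player and advances only the
# state of the player to move: `main` selects which of rstate1/rstate2
# is fed to the network, and the new state is written back to the
# matching slot while the other player's state is left untouched.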
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset() next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8) next_to_play = info["to_play"]
next_to_play_ = info["to_play"] next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_to_play = to_tensor(next_to_play_, device) next_rstate1 = next_rstate2 = init_rnn_state(
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool) args.local_num_envs, args.rnn_channels)
main_player_ = np.concatenate([ eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64), np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
]) ])
np.random.shuffle(main_player_) np.random.shuffle(main_player)
main_player = to_tensor(main_player_, device, dtype=next_to_play.dtype) storage = []
step = 0
@jax.jit
for iteration in range(args.num_iterations): def prepare_data(storage: List[Transition]) -> Transition:
# Annealing the rate if instructed to do so. return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
if args.anneal_lr:
frac = 1.0 - iteration / args.num_iterations for update in range(1, args.num_updates + 2):
lrnow = frac * args.learning_rate if update == 10:
optimizer.param_groups[0]["lr"] = lrnow start_time = time.time()
warmup_step = global_step
model_time = 0
update_time_start = time.time()
inference_time = 0
env_time = 0 env_time = 0
collect_start = time.time() params_queue_get_time_start = time.time()
while step < args.collect_length: if args.concurrency:
global_step += args.num_envs if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
for key in obs: rollout_time_start = time.time()
obs[key][step] = next_obs[key] init_rstate1, init_rstate2 = jax.tree.map(
dones[step] = next_done lambda x: x.copy(), (next_rstate1, next_rstate2))
learn = next_to_play == main_player for _ in range(args.num_steps):
learns[step] = learn global_step += args.local_num_envs * n_actors * args.world_size
_start = time.time() cached_next_obs = next_obs
logits, value = predict_step(traced_model, next_obs) cached_next_done = next_done
if args.fix_target: main = next_to_play == main_player
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t) inference_time_start = time.time()
value = torch.where(learn[:, None], value, value_t) cached_next_obs, cached_next_done, cached_main, \
value = value.flatten() next_rstate1, next_rstate2, action, logits, key = sample_action(
probs = Categorical(logits=logits) params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
action = probs.sample()
logprob = probs.log_prob(action) cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
_start = time.time() _start = time.time()
to_play = next_to_play_ next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_obs, reward, next_done_, info = envs.step(action) next_to_play = info["to_play"]
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer: storage.append(
continue Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=cached_main,
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
for idx, d in enumerate(next_done_): for idx, d in enumerate(next_done):
if d: if not d:
pl = 1 if to_play[idx] == main_player_[idx] else -1 continue
episode_length = info['l'][idx] cur_main = main[idx]
episode_reward = info['r'][idx] * pl for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
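# An episode can end on either player's move, so walk backwards to the
# other player's most recent transition and mark it terminal with the
# negated final reward, letting both players observe the game's end.
# The early break on next_dones stops at a prior episode boundary,
# which covers OTK games where the turn never switched.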
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward) avg_ep_returns.append(episode_reward)
avg_win_rates.append(win) avg_win_rates.append(win)
if random.random() < args.log_p: rollout_time.append(time.time() - rollout_time_start)
n = 100
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n: partitioned_storage = prepare_data(storage)
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step) storage = []
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step) sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
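# prepare_data split each stacked field along the env axis into
# len(learner_devices) chunks; device_put_sharded places one chunk on
# each learner device so the pmapped update can consume them directly.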
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_stat = np.array([eval_return, eval_win_rate])
if device_thread_id != 0:
eval_queue.put(eval_stat)
else:
eval_stats = []
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(eval_queue.get())
eval_stats = np.stack(eval_stats)
eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
other_time += eval_time
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
_start = time.time() from jax.experimental.compilation_cache import compilation_cache as cc
# bootstrap value if not done cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == main_player, value, -value) args.world_size = jax.process_count()
if args.fix_target: args.local_rank = jax.process_index()
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1) args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
nextvalues2 = torch.where(next_to_play != main_player, value_t, -value_t) args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else: else:
nextvalues2 = -nextvalues1 embeddings = None
embedding_shape = None
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
advantages = bootstrap_value_selfplay( local_devices = jax.local_devices()
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda) global_devices = jax.devices()
bootstrap_time = time.time() - _start learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_devices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_devices", global_learner_devices)
args.global_learner_devices = [str(item) for item in global_learner_devices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
_start = time.time() timestamp = int(time.time())
# flatten the batch run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
b_obs = {
k: v[:args.num_steps].reshape((-1,) + v.shape[2:]) writer = SummaryWriter(f"runs/{run_name}")
for k, v in obs.items() writer.add_text(
} "hyperparameters",
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape) "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
b_logprobs = logprobs[:args.num_steps].reshape(-1) )
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1) # seeding
b_returns = b_advantages + b_values random.seed(args.seed)
if args.fix_target: np.random.seed(args.seed)
b_learns = learns[:args.num_steps].reshape(-1) key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
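# Optimizer sketch: global-norm clipping feeds Adam, and
# inject_hyperparams keeps the current learning rate readable from the
# optimizer state (the learner logs it later as charts/learning_rate).
# optax.MultiSteps with every_k_schedule=1 applies on every step, so it
# is effectively a no-op wrapper, presumably kept as a hook for
# gradient accumulation.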
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else: else:
b_learns = torch.ones_like(b_values, dtype=torch.bool) eval_params = None
# Optimizing the policy and value network @jax.jit
b_inds = np.arange(args.local_batch_size) def get_logits_and_value(
clipfracs = [] params: flax.core.FrozenDict, inputs,
for epoch in range(args.update_epochs): ):
np.random.shuffle(b_inds) rstate, logits, value, valid = create_agent(
for start in range(0, args.local_batch_size, args.local_minibatch_size): args, multi_step=True).apply(params, inputs)
end = start + args.local_minibatch_size return logits, value.squeeze(-1)
mb_inds = b_inds[start:end]
mb_obs = { def ppo_loss(
k: v[mb_inds] for k, v in b_obs.items() params, rstate1, rstate2, obs, dones, next_dones,
} switch, actions, logits, rewards, mask, next_value):
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \ # (num_steps * local_num_envs // n_mb))
train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], num_envs = next_value.shape[0]
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args) num_steps = dones.shape[0] // num_envs
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm) mask = mask * (1.0 - dones)
scaler.step(optimizer) n_valids = jnp.sum(mask)
scaler.update() real_dones = dones | next_dones
clipfracs.append(clipfrac.item())
inputs = (rstate1, rstate2, obs, real_dones, switch)
if step > 0: new_logits, new_values = get_logits_and_value(params, inputs)
# TODO: use cyclic buffer to avoid copying
for v in obs.values(): new_values_, rewards, next_dones, switch = jax.tree.map(
v[:step] = v[args.num_steps:].clone() lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
for v in [actions, logprobs, rewards, dones, values, learns]: (new_values, rewards, next_dones, switch),
v[:step] = v[args.num_steps:].clone() )
train_time = time.time() - _start
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if args.fix_target: ratios = distrax.importance_sampling_ratios(distrax.Categorical(
if rank == 0: new_logits), distrax.Categorical(logits), actions)
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device) target_values, advantages = truncated_gae_2p0s(
next_value, new_values_, rewards, next_dones, switch,
args.gamma, args.gae_lambda, args.upgo)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
pg_loss = simple_policy_loss(
ratios, logits, new_logits, advantages, args.spo_kld_max)
else: else:
should_update = torch.zeros((), dtype=torch.int64, device=device) pg_loss = clipped_surrogate_pg_loss(
if args.world_size > 1: ratios, advantages, args.clip_coef, args.dual_clip_coef)
dist.all_reduce(should_update, op=dist.ReduceOp.SUM) pg_loss = jnp.sum(pg_loss * mask)
should_update = should_update.item() > 0
if should_update: v_loss = mse_loss(new_values, target_values)
agent_t.load_state_dict(agent.state_dict()) v_loss = jnp.sum(v_loss * mask)
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False) ent_loss = entropy_loss(new_logits)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t) ent_loss = jnp.sum(ent_loss * mask)
version += 1 pg_loss = pg_loss / n_valids
if rank == 0: v_loss = v_loss / n_valids
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt")) ent_loss = ent_loss / n_valids
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear() loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
avg_ep_returns.clear() return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
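# Each loss term is summed over valid steps and normalized by the same
# n_valids count; steps flagged as done are masked out of all three
# terms and of approx_kl, so they contribute no gradient.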
if args.eval_interval and iteration % args.eval_interval == 0: def single_device_update(
# Eval with rule-based policy agent_state: TrainState,
_start = time.time() sharded_storages: List,
eval_return = evaluate( sharded_init_rstate1: List,
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0] sharded_init_rstate2: List,
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device) sharded_next_inputs: List,
sharded_next_main: List,
# sync the statistics key: jax.random.PRNGKey,
if args.world_size > 1: learn_opponent: bool = False,
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG) ):
eval_return = eval_stats.cpu().numpy() storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
if rank == 0: next_inputs, init_rstate1, init_rstate2 = [
writer.add_scalar("charts/eval_return", eval_return, global_step) jax.tree.map(lambda *x: jnp.concatenate(x), *x)
if local_rank == 0: for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
eval_time = time.time() - _start ]
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}") next_main, = [
jnp.concatenate(x) for x in [sharded_next_main]
]
# reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
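# Reordering trick: sorting by T - mains * num_steps gives every
# main-player step a negative key, so a stable argsort moves the main
# player's steps to the front while preserving time order within each
# player. switch marks the last main-player step per env, i.e. the
# point where the trajectory hands over to the opponent.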
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
def convert_data(x: jnp.ndarray, num_steps):
if args.update_epochs > 1:
x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
N = args.num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
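# Shape walk-through (assuming data of shape (T, B, ...) with
# N = num_minibatches): permute along the env axis, reshape to
# (T, N, B//N, ...), swap to (N, T, B//N, ...), then flatten to
# (N, T * B // N, ...) so each scan step sees one minibatch with time
# kept contiguous, e.g. T=128, B=64, N=8: (128, 64, ...) -> (8, 1024, ...).
# Per-env tensors (num_steps=1) are simply split into (N, B//N, ...).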
shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
# Eval with old model (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_devices,
static_broadcasted_argnums=(7,),
)
if args.world_size > 1: params_queues = []
dist.destroy_process_group() rollout_queues = []
envs.close() eval_queue = queue.Queue()
if rank == 0: dummy_writer = SimpleNamespace()
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt")) dummy_writer.add_scalar = lambda x, y, z: None
writer.close()
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
if eval_params is not None:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# log training and rollout statistics
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
f"{global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if __name__ == "__main__": if args.distributed:
main() jax.distributed.shutdown()
writer.close()
\ No newline at end of file
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Optional
import ygoenv
import numpy as np
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
from ygoai.rl.eval import evaluate
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
fix_target: bool = False
"""if toggled, the target network will be fixed"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.2
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 500
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_embeddings: Optional[int] = None
"""the number of embeddings (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup('nccl', local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
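# torch gets the un-offset seed (args.seed was incremented by rank
# above), so every rank initializes identical model weights while the
# numpy/random streams stay rank-specific.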
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.local_num_envs, local_env_threads)
obs_space = envs.env.observation_space
action_shape = envs.env.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
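# size the eval thread pool to keep roughly the same envs-per-thread ratio as training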
local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
torch.manual_seed(args.seed)
agent.eval()
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
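# With fix_target, the opponent plays with a frozen copy of the agent (agent_t) that is refreshed
# only when the learner clearly beats it (see the update block below); otherwise both sides share the live agent.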
if args.fix_target:
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
else:
agent_t = agent
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def predict_step(agent: Agent, next_obs):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid = agent(next_obs)
return logits, value
if args.compile:
# It seems that calling torch.compile twice causes a segfault at startup, so we use torch.jit.trace here instead
# predict_step = torch.compile(predict_step, mode=args.compile)
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
if args.fix_target:
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
else:
traced_model_t = traced_model
train_step = torch.compile(train_step_, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
train_step = train_step_
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device)
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
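# In half of the envs the learner plays as player 0 and in the other half as player 1,
# shuffled so both seats are trained uniformly.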
main_player_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player_)
main_player = to_tensor(main_player_, device, dtype=next_to_play.dtype)
step = 0
for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - iteration / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
model_time = 0
env_time = 0
collect_start = time.time()
while step < args.collect_length:
global_step += args.num_envs
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
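# `learn` marks the envs where it is the learner's turn to act; in the other envs the (possibly frozen)
# target acts, and those transitions are masked out of the loss via b_learns when fix_target is set.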
learn = next_to_play == main_player
learns[step] = learn
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
if args.fix_target:
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start
rewards[step] = to_tensor(reward, device)
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool)
step += 1
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
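# info['r'] appears to be from the perspective of the player to move at episode end;
# flip its sign when that player is not the learner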
pl = 1 if to_play[idx] == main_player_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if random.random() < args.log_p:
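# log_p subsamples logging across envs; of the episodes that pass, roughly 10/n get
# per-episode scalars and 1/n get the running averages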
n = 100
if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
step = args.collect_length - args.num_steps
_start = time.time()
# bootstrap value if not done
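# values are predicted from the perspective of the player to move; in this zero-sum game
# the opponent's value is the negation, hence the sign flips below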
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == main_player, value, -value)
if args.fix_target:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
nextvalues2 = torch.where(next_to_play != main_player, value_t, -value_t)
else:
nextvalues2 = -nextvalues1
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
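# (the carried-over values are stale: they were computed with the pre-update policy)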
v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps):
v_end = min(v_start + v_steps, step)
v_obs = {
k: v[v_start:v_end].flatten(0, 1) for k, v in obs.items()
}
with torch.no_grad():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value
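# compute GAE advantages with self-play bootstrapping; judging by its arguments, bootstrap_value_selfplay
# bootstraps each player's returns from its own next value (nextvalues1 for the learner, nextvalues2 for the opponent)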
advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
bootstrap_time = time.time() - _start
_start = time.time()
# flatten the batch
b_obs = {
k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_actions = actions[:args.num_steps].reshape((-1,) + action_shape)
b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.fix_target:
b_learns = learns[:args.num_steps].reshape(-1)
else:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if step > 0:
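# shift the last (collect_length - num_steps) transitions to the front of the buffers
# so the next iteration can reuse them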
# TODO: use cyclic buffer to avoid copying
for v in obs.values():
v[:step] = v[args.num_steps:].clone()
for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Treat the first few iterations as warmup so the SPS measurement excludes startup and compilation cost
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if args.fix_target:
if rank == 0:
should_update = (
    len(avg_win_rates) == 1000
    and np.mean(avg_win_rates) > args.update_win_rate
    and np.mean(avg_ep_returns) > args.update_return
)
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# TODO: eval against an older model version
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
if __name__ == "__main__":
main()