Commit 2bf8ce6a authored by sbl1996@126.com

Add jax and lstm

parent 096e743e
@@ -14,18 +14,12 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
-from ygoai.rl.agent import PPOAgent as Agent
-from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
-torch_deterministic: bool = True
-"""if toggled, `torch.backends.cudnn.deterministic=False`"""
-cuda: bool = True
-"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
@@ -60,37 +54,43 @@ class Args:
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
-checkpoint1: Optional[str] = "checkpoints/agent.pt"
-"""the checkpoint to load for the first agent"""
-checkpoint2: Optional[str] = "checkpoints/agent.pt"
-"""the checkpoint to load for the second agent"""
+checkpoint1: str = "checkpoints/agent.pt"
+"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
+checkpoint2: str = "checkpoints/agent.pt"
+"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
+# Jax specific
+xla_device: Optional[str] = None
+"""the XLA device to use, defaults to `None`"""
+# PyTorch specific
+torch_deterministic: bool = True
+"""if toggled, `torch.backends.cudnn.deterministic=False`"""
+cuda: bool = True
+"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
+framework: Optional[Literal["torch", "jax"]] = None
-def predict_step(agent, obs):
-with torch.no_grad():
-logits, values, _valid = agent(obs)
-probs = torch.softmax(logits, dim=-1)
-return probs
if __name__ == "__main__":
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
+if not os.path.exists("replay"):
+os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
-args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
...@@ -101,14 +101,20 @@ if __name__ == "__main__": ...@@ -101,14 +101,20 @@ if __name__ == "__main__":
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
import torch if args.framework is None:
torch.manual_seed(args.seed) args.framework = "jax" if "flax_model" in args.checkpoint1 else "torch"
torch.backends.cudnn.deterministic = args.torch_deterministic
torch.set_num_threads(args.torch_threads) if args.framework == "torch":
torch.set_float32_matmul_precision('high') import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
else:
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs num_envs = args.num_envs
...@@ -124,36 +130,48 @@ if __name__ == "__main__": ...@@ -124,36 +130,48 @@ if __name__ == "__main__":
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
play_mode='self', play_mode='self',
async_reset=False,
verbose=args.verbose, verbose=args.verbose,
record=args.record, record=args.record,
) )
obs_space = envs.observation_space
envs.num_envs = num_envs envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
if args.checkpoint1.endswith(".ptj"): if args.framework == 'torch':
agent1 = torch.jit.load(args.checkpoint1) from ygoai.rl.agent import PPOAgent as Agent
agent2 = torch.jit.load(args.checkpoint2) from ygoai.rl.buffer import create_obs
else:
embedding_shape = args.num_embeddings device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
if embedding_shape is None:
with open(args.code_list_file, "r") as f: if args.checkpoint1.endswith(".ptj"):
code_list = f.readlines() agent1 = torch.jit.load(args.checkpoint1)
embedding_shape = len(code_list) agent2 = torch.jit.load(args.checkpoint2)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
predict_step = torch.compile(predict_step, mode='reduce-overhead')
else: else:
if args.optimize: # count lines of code_list
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
def get_probs(agent, obs):
with torch.no_grad():
return torch.softmax(agent(obs)[0], dim=-1)
if args.compile:
get_probs = torch.compile(get_probs, mode='reduce-overhead')
elif args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device) obs = create_obs(envs.observation_space, (num_envs,), device=device)
def optimize_for_inference(agent): def optimize_for_inference(agent):
with torch.no_grad(): with torch.no_grad():
...@@ -161,9 +179,58 @@ if __name__ == "__main__": ...@@ -161,9 +179,58 @@ if __name__ == "__main__":
return torch.jit.optimize_for_inference(traced_model) return torch.jit.optimize_for_inference(traced_model)
agent1 = optimize_for_inference(agent1) agent1 = optimize_for_inference(agent1)
agent2 = optimize_for_inference(agent2) agent2 = optimize_for_inference(agent2)
def predict_fn(agent, obs):
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
probs = get_probs(agent, obs)
probs = probs.cpu().numpy()
return probs
predict_fn1 = lambda obs: predict_fn(agent1, obs)
predict_fn2 = lambda obs: predict_fn(agent2, obs)
else:
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOAgent
def create_agent(args):
return PPOAgent(
channels=128,
num_layers=2,
embedding_shape=args.num_embeddings,
)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
print(jax.tree.leaves(params)[0].devices())
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read())
if args.checkpoint1 == args.checkpoint2:
params2 = params1
else:
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(params, obs):
probs = get_probs(params, obs)
return np.array(probs)
predict_fn1 = lambda obs: predict_fn(params1, obs)
predict_fn2 = lambda obs: predict_fn(params2, obs)
obs, infos = envs.reset() obs, infos = envs.reset()
next_to_play_ = infos['to_play'] next_to_play = infos['to_play']
episode_rewards = [] episode_rewards = []
episode_lengths = [] episode_lengths = []
...@@ -174,12 +241,10 @@ if __name__ == "__main__": ...@@ -174,12 +241,10 @@ if __name__ == "__main__":
start = time.time() start = time.time()
start_step = step start_step = step
num_envs_half = num_envs // 2 player1 = np.concatenate([
player1_ = np.concatenate([ np.zeros(num_envs // 2, dtype=np.int64),
np.zeros(num_envs_half, dtype=np.int64), np.ones(num_envs - num_envs // 2, dtype=np.int64)
np.ones(num_envs - num_envs_half, dtype=np.int64)
]) ])
player1 = torch.from_numpy(player1_).to(device=device)
model_time = env_time = 0 model_time = env_time = 0
while True: while True:
...@@ -189,21 +254,24 @@ if __name__ == "__main__": ...@@ -189,21 +254,24 @@ if __name__ == "__main__":
model_time = env_time = 0 model_time = env_time = 0
_start = time.time() _start = time.time()
next_to_play = torch.from_numpy(next_to_play_).to(device=device) if args.num_envs != 1:
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs) probs1 = predict_fn1(obs)
probs1 = predict_step(agent1, obs).clone() probs2 = predict_fn2(obs)
probs2 = predict_step(agent2, obs).clone() probs = np.where((next_to_play == player1)[:, None], probs1, probs2)
else:
if (next_to_play == player1).all():
probs = predict_fn1(obs)
else:
probs = predict_fn2(obs)
probs = torch.where((next_to_play == player1)[:, None], probs1, probs2)
probs = probs.cpu().numpy()
actions = probs.argmax(axis=1) actions = probs.argmax(axis=1)
model_time += time.time() - _start model_time += time.time() - _start
to_play = next_to_play_ to_play = next_to_play
_start = time.time() _start = time.time()
obs, rewards, dones, infos = envs.step(actions) obs, rewards, dones, infos = envs.step(actions)
next_to_play_ = infos['to_play'] next_to_play = infos['to_play']
env_time += time.time() - _start env_time += time.time() - _start
step += 1 step += 1
...@@ -211,11 +279,10 @@ if __name__ == "__main__": ...@@ -211,11 +279,10 @@ if __name__ == "__main__":
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if d: if d:
win_reason = infos['win_reason'][idx] win_reason = infos['win_reason'][idx]
pl = 1 if to_play[idx] == player1_[idx] else -1 pl = 1 if to_play[idx] == player1[idx] else -1
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl episode_reward = infos['r'][idx] * pl
win = int(episode_reward > 0)
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length) episode_lengths.append(episode_length)
episode_rewards.append(episode_reward) episode_rewards.append(episode_reward)
...@@ -223,8 +290,8 @@ if __name__ == "__main__": ...@@ -223,8 +290,8 @@ if __name__ == "__main__":
win_reasons.append(1 if win_reason == 1 else 0) win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n") sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
# Only when num_envs=1, we switch the player here
if args.verbose: if args.verbose:
player1_ = 1 - player1_
player1 = 1 - player1 player1 = 1 - player1
if len(episode_lengths) >= args.num_episodes: if len(episode_lengths) >= args.num_episodes:
......
@@ -14,18 +14,12 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
-from ygoai.rl.agent import PPOAgent as Agent
-from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
-torch_deterministic: bool = True
-"""if toggled, `torch.backends.cudnn.deterministic=False`"""
-cuda: bool = True
-"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
@@ -41,7 +35,7 @@ class Args:
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
-n_history_actions: int = 16
+n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
@@ -50,8 +44,6 @@ class Args:
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False
"""whether to play the game"""
-selfplay: bool = False
-"""whether to use selfplay"""
record: bool = False
"""whether to record the game as YGOPro replays"""
@@ -67,27 +59,36 @@ class Args:
strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used"""
-agent: bool = False
-"""whether to use the agent"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
-checkpoint: Optional[str] = "checkpoints/agent.pt"
-"""the checkpoint to load"""
+checkpoint: Optional[str] = None
+"""the checkpoint to load, `pt` or `flax_model` file"""
+# Jax specific
+xla_device: Optional[str] = None
+"""the XLA device to use, defaults to `None`"""
+# PyTorch specific
+torch_deterministic: bool = True
+"""if toggled, `torch.backends.cudnn.deterministic=False`"""
+cuda: bool = True
+"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = True
"""if toggled, the model will be optimized"""
convert: bool = False
"""if toggled, the model will be converted to a jit model and the program will exit"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
+framework: Optional[Literal["torch", "jax"]] = None
if __name__ == "__main__":
args = tyro.cli(Args)
...@@ -102,7 +103,6 @@ if __name__ == "__main__": ...@@ -102,7 +103,6 @@ if __name__ == "__main__":
os.makedirs("replay") os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs) args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file) deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
...@@ -113,15 +113,20 @@ if __name__ == "__main__": ...@@ -113,15 +113,20 @@ if __name__ == "__main__":
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
if args.agent: if args.checkpoint and args.framework is None:
args.framework = "jax" if "flax_model" in args.checkpoint else "torch"
if args.framework == "torch":
import torch import torch
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic torch.backends.cudnn.deterministic = args.torch_deterministic
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads) torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high') torch.set_float32_matmul_precision('high')
else:
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs num_envs = args.num_envs
...@@ -136,15 +141,22 @@ if __name__ == "__main__": ...@@ -136,15 +141,22 @@ if __name__ == "__main__":
player=args.player, player=args.player,
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")), play_mode='human' if args.play else ('bot' if args.bot_type == "greedy" else "random"),
async_reset=False,
verbose=args.verbose, verbose=args.verbose,
record=args.record, record=args.record,
) )
obs_space = envs.observation_space
envs.num_envs = num_envs envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
if args.agent: if args.framework == 'torch':
if args.checkpoint and args.checkpoint.endswith(".ptj"): from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
if args.checkpoint.endswith(".ptj"):
agent = torch.jit.load(args.checkpoint) agent = torch.jit.load(args.checkpoint)
else: else:
# count lines of code_list # count lines of code_list
...@@ -155,12 +167,11 @@ if __name__ == "__main__": ...@@ -155,12 +167,11 @@ if __name__ == "__main__":
embedding_shape = len(code_list) embedding_shape = len(code_list)
L = args.num_layers L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device) agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint: state_dict = torch.load(args.checkpoint, map_location=device)
state_dict = torch.load(args.checkpoint, map_location=device) if not args.compile:
if not args.compile: prefix = "_orig_mod."
prefix = "_orig_mod." state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()} print(agent.load_state_dict(state_dict))
print(agent.load_state_dict(state_dict))
if args.compile: if args.compile:
if args.convert: if args.convert:
...@@ -191,6 +202,48 @@ if __name__ == "__main__": ...@@ -191,6 +202,48 @@ if __name__ == "__main__":
print(f"Optimized model saved to {args.checkpoint}j") print(f"Optimized model saved to {args.checkpoint}j")
exit(0) exit(0)
def predict_fn(obs):
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits = agent(obs)[0]
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
return probs
else:
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOAgent
def create_agent(args):
return PPOAgent(
channels=128,
num_layers=2,
embedding_shape=args.num_embeddings,
)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(obs):
probs = get_probs(params, obs)
return np.array(probs)
print(f"loaded checkpoint from {args.checkpoint}")
obs, infos = envs.reset() obs, infos = envs.reset()
next_to_play = infos['to_play'] next_to_play = infos['to_play']
...@@ -210,16 +263,11 @@ if __name__ == "__main__": ...@@ -210,16 +263,11 @@ if __name__ == "__main__":
start_step = step start_step = step
model_time = env_time = 0 model_time = env_time = 0
if args.agent: if args.framework:
_start = time.time() _start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs) probs = predict_fn(obs)
with torch.no_grad():
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
if args.verbose: if args.verbose:
print([f"{p:.4f}" for p in probs[probs != 0].tolist()]) print([f"{p:.4f}" for p in probs[probs != 0].tolist()])
print(f"{values[0].item():.4f}")
actions = probs.argmax(axis=1) actions = probs.argmax(axis=1)
model_time += time.time() - _start model_time += time.time() - _start
else: else:
...@@ -228,13 +276,6 @@ if __name__ == "__main__": ...@@ -228,13 +276,6 @@ if __name__ == "__main__":
else: else:
actions = np.zeros(num_envs, dtype=np.int32) actions = np.zeros(num_envs, dtype=np.int32)
# for k, v in obs.items():
# v = v[0]
# if k == 'cards_':
# v = np.concatenate([np.arange(v.shape[0])[:, None], v], axis=1)
# print(k, v.tolist())
# print(infos)
# print(actions[0])
to_play = next_to_play to_play = next_to_play
_start = time.time() _start = time.time()
...@@ -249,15 +290,7 @@ if __name__ == "__main__": ...@@ -249,15 +290,7 @@ if __name__ == "__main__":
win_reason = infos['win_reason'][idx] win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] episode_reward = infos['r'][idx]
if args.selfplay: win = int(episode_reward > 0)
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
if episode_reward < 0:
win = 0
else:
win = 1
episode_lengths.append(episode_length) episode_lengths.append(episode_length)
episode_rewards.append(episode_reward) episode_rewards.append(episode_reward)
......
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 100
"""the frequency of saving the model"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 64
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 20
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
num_minibatches: int = 4
"""the number of mini-batches"""
gradient_accumulation_steps: int = 1
"""the number of gradient accumulation steps before performing an optimization step"""
c_clip_min: float = 0.001
"""the minimum value of the importance sampling clipping"""
c_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping"""
rho_clip_min: float = 0.001
"""the minimum value of the importance sampling clipping"""
rho_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping"""
upgo: bool = False
"""whether to use UPGO for policy update"""
ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25
"""the PPO surrogate clipping coefficient"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [1])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
num_updates: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logitss: list
rewards: list
learns: list
def create_agent(args):
return PPOAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
)
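# Note: as configured above, `dtype` (the compute/activation dtype) is bfloat16 when
# args.bfloat16 is set, while `param_dtype` stays float32, so parameters are
# presumably stored and updated in full precision.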
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def apply_fn(
params: flax.core.FrozenDict,
next_obs,
):
logits, value, _valid = create_agent(args).apply(params, next_obs)
return logits, value
def get_action(
params: flax.core.FrozenDict,
next_obs,
):
return apply_fn(params, next_obs)[0].argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs,
key: jax.random.PRNGKey,
):
next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
logits = apply_fn(params, next_obs)[0]
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
return next_obs, action, logits, key
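    # A minimal, illustrative sketch (not part of the original script and unused
    # below): the Gumbel-max trick above is equivalent to drawing from the
    # categorical distribution defined by `logits`, e.g. via jax.random.categorical.
    @jax.jit
    def sample_action_categorical(
        params: flax.core.FrozenDict,
        next_obs,
        key: jax.random.PRNGKey,
    ):
        next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
        logits = apply_fn(params, next_obs)[0]
        key, subkey = jax.random.split(key)
        # sample one action per row of logits
        action = jax.random.categorical(subkey, logits, axis=1)
        return next_obs, action, logits, key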
# put data in the last index
envs.async_reset()
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
storage = []
ai_player1 = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
next_to_play = None
learn = np.ones(args.local_num_envs, dtype=np.bool_)
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree_map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
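# prepare_data stacks the per-step Transition list into arrays of shape
# (num_steps + 1, local_num_envs, ...) and splits them along the env axis into
# len(learner_devices) shards, one shard per learner device.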
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
num_steps_with_bootstrap = (
args.num_steps + int(len(storage) == 0)
) # num_steps + 1 to get the states for value bootstrapping.
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
for _ in range(0, num_steps_with_bootstrap):
global_step += args.local_num_envs * n_actors * args.world_size
_start = time.time()
next_obs, next_reward, next_done, info = envs.recv()
next_reward = np.where(learn, next_reward, -next_reward)
env_time += time.time() - _start
to_play = next_to_play
next_to_play = info["to_play"]
learn = next_to_play == ai_player1
inference_time_start = time.time()
next_obs, action, logits, key = sample_action(params, next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
envs.send(cpu_action)
storage.append(
Transition(
obs=next_obs,
dones=next_done,
actions=action,
logitss=logits,
rewards=next_reward,
learns=learn,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
partitioned_storage = prepare_data(storage)
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
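# Each shard built above is placed directly on its corresponding learner device via
# jax.device_put_sharded, so the learner thread receives device-resident arrays.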
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
# move bootstrapping step to the beginning of the next update
storage = storage[-1:]
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns) if len(avg_ep_returns) > 0 else 0
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(
args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(
args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) *
args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(
len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * \
args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (
args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [
str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % (
"\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches)) / args.num_updates
return args.learning_rate * frac
agent = create_agent(args)
params = agent.init(agent_key, sample_obs)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=args.gradient_accumulation_steps,
)
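# optax.MultiSteps accumulates gradients over `gradient_accumulation_steps`
# consecutive calls and applies the clipped Adam update only on every k-th call,
# emitting zero updates in between.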
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
agent_state = flax.jax_utils.replicate(
agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict,
obs: np.ndarray,
):
logits, value, valid = create_agent(args).apply(params, obs)
return logits, value.squeeze(-1), valid
def impala_loss(params, obs, actions, logitss, rewards, dones, learns):
# (num_steps + 1, local_num_envs // n_mb)
discounts = (1.0 - dones) * args.gamma
policy_logits, newvalue, valid = jax.vmap(
get_logits_and_value, in_axes=(None, 0))(params, obs)
newvalue = jnp.where(learns, newvalue, -newvalue)
v_t = newvalue[1:]
# Remove the bootstrap timestep from non-terminal timesteps.
v_tm1 = newvalue[:-1]
policy_logits = policy_logits[:-1]
logitss = logitss[:-1]
actions = actions[:-1]
mask = 1.0 - dones
rewards = rewards[1:]
discounts = discounts[1:]
mask = mask[:-1]
rhos = rlax.categorical_importance_sampling_ratios(
policy_logits, logitss, actions)
vtrace_fn = partial(
vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
vtrace_returns = jax.vmap(
vtrace_fn, in_axes=1, out_axes=1)(
v_tm1, v_t, rewards, discounts, rhos)
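# V-trace targets (Espeholt et al., 2018), which the imported `vtrace` helper is
# assumed to implement:
#   delta_t = rho_t * (r_t + gamma_t * V(x_{t+1}) - V(x_t))
#   vs_t    = V(x_t) + delta_t + gamma_t * c_t * (vs_{t+1} - V(x_{t+1}))
# with rho_t and c_t the clipped importance weights derived from `rhos`.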
jax.debug.print("R {}", jnp.where(dones[1:-1, :2], rewards[:-1, :2], 0).T)
jax.debug.print("E {}", jnp.where(dones[1:-1, :2], vtrace_returns.errors[:-1, :2] * 100, vtrace_returns.errors[:-1, :2]).T)
jax.debug.print("V {}", v_tm1[:-1, :2].T)
T = v_tm1.shape[0]
if args.upgo:
advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
rewards, v_t, discounts) - v_tm1
else:
advs = vtrace_returns.q_estimate - v_tm1
if args.ppo_clip:
pg_loss = jax.vmap(
partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
rhos, advs, mask) * T
pg_loss = jnp.sum(pg_loss)
else:
pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
pg_loss = jax.vmap(
rlax.policy_gradient_loss, in_axes=1)(
policy_logits, actions, pg_advs, mask) * T
pg_loss = jnp.sum(pg_loss)
baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
ent_loss = jax.vmap(rlax.entropy_loss, in_axes=1)(
policy_logits, mask) * T
ent_loss = jnp.sum(ent_loss)
n_samples = jnp.sum(mask)
pg_loss = pg_loss / n_samples
baseline_loss = baseline_loss / n_samples
ent_loss = ent_loss / n_samples
total_loss = pg_loss
total_loss += args.vf_coef * baseline_loss
total_loss += args.ent_coef * ent_loss
return total_loss, (pg_loss, baseline_loss, ent_loss)
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List[Transition],
key: jax.random.PRNGKey,
):
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
impala_loss_grad_fn = jax.value_and_grad(impala_loss, has_aux=True)
def update_minibatch(agent_state, minibatch):
mb_obs, mb_actions, mb_logitss, mb_rewards, mb_dones, mb_learns = minibatch
(loss, (pg_loss, v_loss, entropy_loss)), grads = impala_loss_grad_fn(
agent_state.params,
mb_obs,
mb_actions,
mb_logitss,
mb_rewards,
mb_dones,
mb_learns,
)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss)
n_mb = args.num_minibatches * args.gradient_accumulation_steps
storage_obs = {
k: jnp.array(jnp.split(v, n_mb, axis=1))
for k, v in storage.obs.items()
}
agent_state, (loss, pg_loss, v_loss, entropy_loss) = jax.lax.scan(
update_minibatch,
agent_state,
(
# (num_steps + 1, local_num_envs) => (n_mb, num_steps + 1, local_num_envs // n_mb)
storage_obs,
jnp.array(jnp.split(storage.actions, n_mb, axis=1)),
jnp.array(jnp.split(storage.logitss, n_mb, axis=1)),
jnp.array(jnp.split(storage.rewards, n_mb, axis=1)),
jnp.array(jnp.split(storage.dones, n_mb, axis=1)),
jnp.array(jnp.split(storage.learns, n_mb, axis=1)),
),
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(
entropy_loss, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(
unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_storages = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
sharded_storage,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_storages.append(sharded_storage)
rollout_queue_get_time.append(
time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, learner_keys) = multi_device_update(
agent_state,
sharded_storages,
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(
unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads +
thread_id].put(device_params)
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time",
np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time",
time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size",
rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size",
params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss",
v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss",
pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy",
entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints/{run_name}"
os.makedirs(ckpt_dir, exist_ok=True)
model_path = ckpt_dir + "/agent.cleanrl_model"
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(
[
vars(args),
unreplicated_params,
]
)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
@@ -38,8 +38,10 @@ class Args:
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
-save_interval: int = 100
-"""the frequency of saving the model"""
+save_interval: int = 400
+"""the frequency of saving the model (in terms of `updates`)"""
+checkpoint: Optional[str] = None
+"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
@@ -89,6 +91,8 @@ class Args:
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
+spo_kld_max: Optional[float] = None
+"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
@@ -101,9 +105,9 @@ class Args:
num_channels: int = 128
"""the number of channels for the agent"""
-actor_device_ids: List[int] = field(default_factory=lambda: [0])
+actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
-learner_device_ids: List[int] = field(default_factory=lambda: [1])
+learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
@@ -122,7 +126,6 @@ class Args:
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
-num_updates: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
@@ -165,6 +168,7 @@ class Transition(NamedTuple):
logprobs: list
rewards: list
learns: list
+probs: list
def create_agent(args):
@@ -231,7 +235,7 @@ def rollout(
next_obs,
key: jax.random.PRNGKey,
):
-next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
+next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = get_logits(params, next_obs)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
@@ -239,7 +243,11 @@ def rollout(
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
-return next_obs, action, logprob, key
+logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
+logits = logits.clip(min=jnp.finfo(logits.dtype).min)
+probs = jax.nn.softmax(logits)
+return next_obs, action, logprob, probs, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
@@ -258,7 +266,7 @@ def rollout(
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
-return jax.tree_map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
+return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
@@ -290,7 +298,7 @@ def rollout(
learn = next_to_play == ai_player1
inference_time_start = time.time()
-cached_next_obs, action, logprob, key = sample_action(params, cached_next_obs, key)
+cached_next_obs, action, logprob, probs, key = sample_action(params, cached_next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
@@ -308,6 +316,7 @@ def rollout(
logprobs=logprob,
rewards=next_reward,
learns=learn,
+probs=probs,
)
)
@@ -338,7 +347,7 @@ def rollout(
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_learn = ai_player1 == next_to_play
-sharded_data = jax.tree_map(lambda x: jax.device_put_sharded(
+sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_done, next_learn))
payload = (
...@@ -460,7 +469,7 @@ if __name__ == "__main__": ...@@ -460,7 +469,7 @@ if __name__ == "__main__":
obs_space = envs.observation_space obs_space = envs.observation_space
action_shape = envs.action_space.shape action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}") print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close() envs.close()
del envs del envs
...@@ -486,6 +495,11 @@ if __name__ == "__main__": ...@@ -486,6 +495,11 @@ if __name__ == "__main__":
params=params, params=params,
tx=tx, tx=tx,
) )
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices) agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs)) # print(agent.tabulate(agent_key, sample_obs))
...@@ -498,14 +512,16 @@ if __name__ == "__main__": ...@@ -498,14 +512,16 @@ if __name__ == "__main__":
): ):
logits, value, valid = create_agent(args).apply(params, obs) logits, value, valid = create_agent(args).apply(params, obs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions] logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min) logits = logits.clip(min=jnp.finfo(logits.dtype).min)
p_log_p = logits * jax.nn.softmax(logits) probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1) entropy = -p_log_p.sum(-1)
return logprob, entropy, value.squeeze(), valid return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(params, obs, actions, logprobs, advantages, target_values): def ppo_loss(params, obs, actions, logprobs, probs, advantages, target_values):
newlogprob, entropy, newvalue, valid = get_logprob_entropy_value(params, obs, actions) newlogprob, newprobs, entropy, newvalue, valid = get_logprob_entropy_value(params, obs, actions)
logratio = newlogprob - logprobs logratio = newlogprob - logprobs
ratio = jnp.exp(logratio) ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean() approx_kl = ((ratio - 1) - logratio).mean()
...@@ -514,9 +530,20 @@ if __name__ == "__main__": ...@@ -514,9 +530,20 @@ if __name__ == "__main__":
advantages = masked_normalize(advantages, valid, eps=1e-8) advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss # Policy loss
pg_loss1 = -advantages * ratio if args.spo_kld_max is not None:
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) eps = 1e-8
pg_loss = jnp.maximum(pg_loss1, pg_loss2) kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid) pg_loss = masked_mean(pg_loss, valid)
# Value loss # Value loss
...@@ -539,11 +566,15 @@ if __name__ == "__main__": ...@@ -539,11 +566,15 @@ if __name__ == "__main__":
def flatten(x): def flatten(x):
return x.reshape((-1,) + x.shape[2:]) return x.reshape((-1,) + x.shape[2:])
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages) storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree_map(lambda *x: jnp.concatenate(x), *sharded_next_obs) next_obs = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_next_obs)
next_done, next_learn = [ next_done, next_learn = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_learn] jnp.concatenate(x) for x in [sharded_next_done, sharded_next_learn]
] ]
print(jax.tree_map(lambda x: x.shape, storage))
print(jax.tree_map(lambda x: x.shape, next_obs))
print(next_done.shape, next_learn.shape)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _): def update_epoch(carry, _):
...@@ -554,13 +585,13 @@ if __name__ == "__main__": ...@@ -554,13 +585,13 @@ if __name__ == "__main__":
values = create_agent(args).apply(agent_state.params, mb_obs)[1].squeeze(-1) values = create_agent(args).apply(agent_state.params, mb_obs)[1].squeeze(-1)
return agent_state, values return agent_state, values
flatten_obs = jax.tree_map(lambda x: x.reshape((-1, args.local_minibatch_size * 8) + x.shape[2:]), storage.obs) flatten_obs = jax.tree.map(lambda x: x.reshape((-1, args.local_minibatch_size * 8) + x.shape[2:]), storage.obs)
_, values = jax.lax.scan( _, values = jax.lax.scan(
get_value_minibatch, agent_state, flatten_obs) get_value_minibatch, agent_state, flatten_obs)
values = values.reshape(storage.rewards.shape) values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(agent_state.params, next_obs)[1].squeeze(-1) next_value = create_agent(args).apply(agent_state.params, next_obs)[1].squeeze(-1)
compute_gae_fn = compute_gae_upgo if args.upgo else compute_gae compute_gae_fn = compute_gae_upgo if args.upgo else compute_gae
advantages, target_values = compute_gae_fn( advantages, target_values = compute_gae_fn(
next_value, next_done, next_learn, next_value, next_done, next_learn,
...@@ -574,20 +605,21 @@ if __name__ == "__main__": ...@@ -574,20 +605,21 @@ if __name__ == "__main__":
x = jnp.reshape(x, (-1, args.local_minibatch_size) + x.shape[1:]) x = jnp.reshape(x, (-1, args.local_minibatch_size) + x.shape[1:])
return x return x
flatten_storage = jax.tree_map(flatten, jax.tree_map(lambda x: x[:args.num_steps], storage)) flatten_storage = jax.tree.map(flatten, jax.tree.map(lambda x: x[:args.num_steps], storage))
flatten_advantages = flatten(advantages) flatten_advantages = flatten(advantages)
flatten_target_values = flatten(target_values) flatten_target_values = flatten(target_values)
shuffled_storage = jax.tree_map(convert_data, flatten_storage) shuffled_storage = jax.tree.map(convert_data, flatten_storage)
shuffled_advantages = convert_data(flatten_advantages) shuffled_advantages = convert_data(flatten_advantages)
shuffled_target_values = convert_data(flatten_target_values) shuffled_target_values = convert_data(flatten_target_values)
def update_minibatch(agent_state, minibatch): def update_minibatch(agent_state, minibatch):
mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_target_values = minibatch mb_obs, mb_actions, mb_logprobs, mb_probs, mb_advantages, mb_target_values = minibatch
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, agent_state.params,
mb_obs, mb_obs,
mb_actions, mb_actions,
mb_logprobs, mb_logprobs,
mb_probs,
mb_advantages, mb_advantages,
mb_target_values, mb_target_values,
) )
...@@ -602,6 +634,7 @@ if __name__ == "__main__": ...@@ -602,6 +634,7 @@ if __name__ == "__main__":
shuffled_storage.obs, shuffled_storage.obs,
shuffled_storage.actions, shuffled_storage.actions,
shuffled_storage.logprobs, shuffled_storage.logprobs,
shuffled_storage.probs,
shuffled_advantages, shuffled_advantages,
shuffled_target_values, shuffled_target_values,
), ),
...@@ -720,17 +753,13 @@ if __name__ == "__main__": ...@@ -720,17 +753,13 @@ if __name__ == "__main__":
writer.add_scalar("losses/loss", loss[-1].item(), global_step) writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0: if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints/{run_name}" ckpt_dir = f"checkpoints"
os.makedirs(ckpt_dir, exist_ok=True) os.makedirs(ckpt_dir, exist_ok=True)
model_path = ckpt_dir + "/agent.flax_model" M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f: with open(model_path, "wb") as f:
f.write( f.write(
flax.serialization.to_bytes( flax.serialization.to_bytes(unreplicated_params)
[
vars(args),
unreplicated_params,
]
)
) )
print(f"model saved to {model_path}") print(f"model saved to {model_path}")
......
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
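# Example invocation (hypothetical script name and device layout; assumes tyro's
# default CLI, which turns the underscored fields above into kebab-case flags):
#   python ppo_selfplay_jax.py --deck ../assets/deck \
#       --local-num-envs 128 --num-actor-threads 2 \
#       --actor-device-ids 0 1 --learner-device-ids 2 3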
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logprobs: list
rewards: list
learns: list
probs: list
def create_agent(args):
return PPOAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, next_obs):
return create_agent(args).apply(params, next_obs)[0]
def get_action(
params: flax.core.FrozenDict, next_obs):
return get_logits(params, next_obs).argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, key: jax.random.PRNGKey):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = get_logits(params, next_obs)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
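        # Taking argmax of logits plus Gumbel noise (-log(-log(u))) draws an exact
        # sample from the softmax distribution. The renormalized, clipped probs are
        # returned as well so they can be stored in the rollout and reused by the
        # optional SPO KL term in ppo_loss.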
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
return next_obs, action, logprob, probs, key
    # rolling windows for profiling the rollout loop
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
ai_player1 = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
start_step = 0
storage = []
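    # prepare_data stacks the list of per-step Transitions into arrays of shape
    # (num_steps, local_num_envs, ...) and splits them along the env axis into one
    # shard per learner device, ready for jax.device_put_sharded below.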
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
learn = next_to_play == ai_player1
inference_time_start = time.time()
cached_next_obs, action, logprob, probs, key = sample_action(
params, cached_next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
actions=action,
logprobs=logprob,
rewards=next_reward,
learns=learn,
probs=probs,
)
)
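            # Self-play bookkeeping: when an episode ends, walk backwards and mark the
            # opponent's last transition as done with the negated terminal reward, so
            # each player's trajectory ends with its own outcome.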
for idx, d in enumerate(next_done):
if not d:
continue
cur_learn = learn[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.dones[idx]:
# For OTK where player may not switch
break
if t.learns[idx] != cur_learn:
t.dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_learn = ai_player1 == next_to_play
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_done, next_learn))
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
        # anneal the learning rate linearly over training; each update performs
        # (args.num_minibatches * args.update_epochs) gradient steps
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
agent = create_agent(args)
params = agent.init(agent_key, sample_obs)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
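    # Optimizer: global-norm gradient clipping followed by Adam. inject_hyperparams
    # keeps the learning rate inside opt_state so it can be logged later; MultiSteps
    # with every_k_schedule=1 applies every step (no gradient accumulation) while
    # leaving the option of accumulating over k steps.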
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logprob_entropy_value(
params: flax.core.FrozenDict, obs, actions,
):
logits, value, valid = create_agent(args).apply(params, obs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
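        # Renormalize the logits and clip them to the smallest finite value so that
        # masked (effectively -inf) entries contribute 0 * finite rather than
        # 0 * -inf = nan to the entropy below.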
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1)
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(
params, obs, actions, logprobs, probs, advantages, target_values):
newlogprob, newprobs, entropy, newvalue, valid = \
get_logprob_entropy_value(params, obs, actions)
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_done: List,
sharded_next_learn: List,
key: jax.random.PRNGKey,
):
def flatten(x):
return x.reshape((-1,) + x.shape[2:])
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_next_obs)
next_done, next_learn = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_learn]
]
# reorder storage of individual players
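        # learns marks the steps owned by the same player that acts at the bootstrap
        # step; the argsort moves that player's steps after the opponent's within each
        # env, and `switch` flags the step where control passes between the players,
        # which compute_gae_2p0s / compute_gae_upgo_2p0s use to handle the hand-off
        # between the two trajectories.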
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
learns = (storage.learns == next_learn).astype(jnp.int32)
indices = jnp.argsort(T[:, None] + learns * num_steps, axis=0)
switch = T[:, None] == (num_steps - 1 - jnp.sum(learns, axis=0))
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
flatten_obs = jax.tree.map(lambda x: x.reshape((-1, args.local_minibatch_size * 8) + x.shape[2:]), storage.obs)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def get_value_minibatch(agent_state, mb_obs):
values = create_agent(args).apply(
agent_state.params, mb_obs)[1].squeeze(-1)
return agent_state, values
_, values = jax.lax.scan(
get_value_minibatch, agent_state, flatten_obs)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, next_obs)[1].squeeze(-1)
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda)
advantages = advantages[:args.num_steps]
target_values = target_values[:args.num_steps]
def convert_data(x: jnp.ndarray):
x = jax.random.permutation(subkey, x)
x = jnp.reshape(x, (-1, args.local_minibatch_size) + x.shape[1:])
return x
flatten_storage = jax.tree.map(flatten, jax.tree.map(lambda x: x[:args.num_steps], storage))
flatten_advantages = flatten(advantages)
flatten_target_values = flatten(target_values)
shuffled_storage, shuffled_advantages, shuffled_target_values = jax.tree.map(
convert_data, (flatten_storage, flatten_advantages, flatten_target_values))
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_storage.obs,
shuffled_storage.actions,
shuffled_storage.logprobs,
shuffled_storage.probs,
shuffled_advantages,
shuffled_target_values,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
)
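    # Asynchronous actor-learner setup: each actor thread runs its own envs and pushes
    # sharded rollouts into a rollout_queue(maxsize=1); the learner pmaps the update
    # across learner devices and pushes fresh params back through a
    # params_queue(maxsize=1), so with args.concurrency=True the actors run at most
    # one policy version behind the learner.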
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_storages = []
sharded_next_obss = []
sharded_next_dones = []
sharded_next_learns = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
sharded_storage,
sharded_next_obs,
sharded_next_done,
sharded_next_learn,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_storages.append(sharded_storage)
sharded_next_obss.append(sharded_next_obs)
sharded_next_dones.append(sharded_next_done)
sharded_next_learns.append(sharded_next_learn)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
sharded_storages,
sharded_next_obss,
sharded_next_dones,
sharded_next_learns,
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo2, compute_gae2
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
lstm_channels: int = 512
"""the number of channels for the LSTM in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logprobs: list
rewards: list
learns: list
probs: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.lstm_channels,
multi_step=multi_step,
)
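# init_carry builds the recurrent state for PPOLSTMAgent: a pair of zero
# (num_envs, lstm_channels) arrays (the LSTM hidden/cell states); the rollout keeps
# one such pair per player during self-play.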
def init_carry(num_envs, lstm_channels):
return (
np.zeros((num_envs, lstm_channels)),
np.zeros((num_envs, lstm_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs, done):
carry, logits = create_agent(args).apply(params, inputs)[:2]
carry = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), carry)
return carry, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
carry, logits = get_logits(params, inputs, done)
return carry, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, carry1, carry2, learn, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
learn = jnp.array(learn)
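        # Two recurrent states are tracked per env: carry1 for the learning agent and
        # carry2 for its opponent. Select whichever belongs to the player acting now,
        # step the LSTM, then write the updated carry back to the matching slot.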
carry = jax.tree.map(
lambda x1, x2: jnp.where(learn[:, None], x1, x2), carry1, carry2)
carry, logits = get_logits(params, (carry, next_obs), done)
carry1 = jax.tree.map(lambda x, y: jnp.where(learn[:, None], x, y), carry, carry1)
carry2 = jax.tree.map(lambda x, y: jnp.where(learn[:, None], y, x), carry, carry2)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
return next_obs, carry1, carry2, action, logprob, probs, key
    # rolling windows for profiling the rollout loop
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_lstm_state1 = next_lstm_state2 = init_carry(
args.local_num_envs, args.lstm_channels)
eval_rnn_state = init_carry(
args.local_eval_episodes, args.lstm_channels)
ai_player1 = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
start_step = 0
storage = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
initial_lstm_state1, initial_lstm_state2 = jax.tree.map(
lambda x: x.copy(), (next_lstm_state1, next_lstm_state2))
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
learn = next_to_play == ai_player1
inference_time_start = time.time()
cached_next_obs, next_lstm_state1, next_lstm_state2, action, logprob, probs, key = sample_action(
params, cached_next_obs, next_lstm_state1, next_lstm_state2, learn, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
actions=action,
logprobs=logprob,
rewards=next_reward,
learns=learn,
probs=probs,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
cur_learn = learn[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.dones[idx]:
# For OTK where player may not switch
break
if t.learns[idx] != cur_learn:
t.dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_learn = ai_player1 == next_to_play
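        # Alongside the usual tensors, ship the recurrent state: next_lstm_state is
        # the carry of the player that acts at the bootstrap step, while carry1/carry2
        # are the two players' carries from the start of this collection window,
        # reordered against next_learn so the learner can re-unroll the LSTM over the
        # stored steps.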
next_lstm_state = jax.tree.map(
lambda x1, x2: jnp.where(next_learn[:, None], x1, x2), next_lstm_state1, next_lstm_state2)
carry1 = jax.tree.map(
lambda x, y: jnp.where(next_learn[:, None], x, y), initial_lstm_state1, initial_lstm_state2)
carry2 = jax.tree.map(
lambda x, y: jnp.where(next_learn[:, None], y, x), initial_lstm_state1, initial_lstm_state2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_lstm_state, carry1, carry2, next_done, next_learn))
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rnn_state)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
        # anneal the learning rate linearly over training; each update performs
        # (args.num_minibatches * args.update_epochs) gradient steps
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
carry = init_carry(1, args.lstm_channels)
agent = create_agent(args)
params = agent.init(agent_key, (carry, sample_obs))
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logprob_entropy_value(
params: flax.core.FrozenDict, inputs, actions,
):
_carry, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1)
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(
params, inputs, actions, logprobs, probs, advantages, target_values):
newlogprob, newprobs, entropy, newvalue, valid = \
get_logprob_entropy_value(params, inputs, actions)
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_carry: List,
sharded_carry1: List,
sharded_carry2: List,
sharded_next_done: List,
sharded_next_learn: List,
key: jax.random.PRNGKey,
):
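        # reshape_minibatch partitions the batch into num_minibatches groups of whole
        # environments: with num_steps > 1 it keeps each env's full time sequence in
        # one group (time-major), which the multi-step LSTM unroll requires.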
def reshape_minibatch(x, num_minibatches, num_steps=1):
N = num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_next_obs)
next_carry = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_next_carry)
carry1 = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_carry1)
carry2 = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_carry2)
next_done, next_learn = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_learn]
]
# reorder storage of individual players
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
learns = (storage.learns == next_learn).astype(jnp.int32)
indices = jnp.argsort(T[:, None] + learns * num_steps, axis=0)
switch = T[:, None] == (num_steps - 1 - jnp.sum(learns, axis=0))
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
# split minibatches for recompute values
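        # Values are recomputed before GAE using larger chunks (num_minibatches // 4
        # of them); each chunk packs (carry1, carry2, obs, dones, switch) so the agent
        # can be applied in multi_step mode over whole sequences.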
n_mbs = args.num_minibatches // 4
flatten_carry = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs),
(carry1, carry2))
flatten_inputs = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs, num_steps=args.num_steps),
(storage.obs, storage.dones, switch))
flatten_inputs = flatten_carry + flatten_inputs
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def get_value_minibatch(agent_state, mb_inputs):
values = create_agent(args, multi_step=True).apply(
agent_state.params, mb_inputs)[2].squeeze(-1)
return agent_state, values
_, values = jax.lax.scan(
get_value_minibatch, agent_state, flatten_inputs)
values = values.reshape((n_mbs, args.num_steps, -1)).transpose(1, 0, 2)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, (next_carry, next_obs))[2].squeeze(-1)
compute_gae_fn = compute_gae_upgo2 if args.upgo else compute_gae2
advantages, target_values = compute_gae_fn(
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda)
advantages = advantages[:args.num_steps]
target_values = target_values[:args.num_steps]
def convert_data(x: jnp.ndarray, num_steps=1):
x = jax.random.permutation(subkey, x, axis=1)
return reshape_minibatch(x, args.num_minibatches, num_steps)
shuffled_carry1, shuffled_carry2 = jax.tree.map(
partial(convert_data, num_steps=1), (carry1, carry2))
shuffled_storage, shuffled_switch, shuffled_advantages, shuffled_target_values = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch, advantages, target_values))
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
(
shuffled_carry1,
shuffled_carry2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_switch,
),
shuffled_storage.actions,
shuffled_storage.logprobs,
shuffled_storage.probs,
shuffled_advantages,
shuffled_target_values,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_storages = []
sharded_next_obss = []
sharded_next_carries = []
sharded_carries1 = []
sharded_carries2 = []
sharded_next_dones = []
sharded_next_learns = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
sharded_storage,
sharded_next_obs,
sharded_next_carry,
sharded_carry1,
sharded_carry2,
sharded_next_done,
sharded_next_learn,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_storages.append(sharded_storage)
sharded_next_obss.append(sharded_next_obs)
sharded_next_carries.append(sharded_next_carry)
sharded_carries1.append(sharded_carry1)
sharded_carries2.append(sharded_carry2)
sharded_next_dones.append(sharded_next_done)
sharded_next_learns.append(sharded_next_learn)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
sharded_storages,
sharded_next_obss,
sharded_next_carries,
sharded_carries1,
sharded_carries2,
sharded_next_dones,
sharded_next_learns,
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
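# M_steps: total environment steps trained so far, expressed in units of 2**20 (~1M) for the checkpoint filename.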
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
...@@ -74,9 +74,16 @@ class Args: ...@@ -74,9 +74,16 @@ class Args:
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0 gamma: float = 1.0
"""the discount factor gamma""" """the discount factor gamma"""
gae_lambda: float = 0.98 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
fix_target: bool = False
"""if toggled, the target network will be fixed"""
update_win_rate: float = 0.55
"""the required win rate to update the agent"""
update_return: float = 0.1
"""the required return to update the agent"""
minibatch_size: int = 256 minibatch_size: int = 256
"""the mini-batch size""" """the mini-batch size"""
update_epochs: int = 2 update_epochs: int = 2
...@@ -93,8 +100,6 @@ class Args: ...@@ -93,8 +100,6 @@ class Args:
"""coefficient of the value function""" """coefficient of the value function"""
max_grad_norm: float = 1.0 max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)""" """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
...@@ -169,6 +174,7 @@ def main(): ...@@ -169,6 +174,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size) args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps) args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size) args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps args.collect_length = args.collect_length or args.num_steps
...@@ -247,6 +253,13 @@ def main(): ...@@ -247,6 +253,13 @@ def main():
if args.embedding_file: if args.embedding_file:
agent.freeze_embeddings() agent.freeze_embeddings()
if args.fix_target:
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
else:
agent_t = agent
optim_params = list(agent.parameters()) optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5) optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
...@@ -265,10 +278,16 @@ def main(): ...@@ -265,10 +278,16 @@ def main():
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device) example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad(): with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False) traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
if args.fix_target:
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
else:
traced_model_t = traced_model
train_step = torch.compile(train_step, mode=args.compile) train_step = torch.compile(train_step, mode=args.compile)
else: else:
traced_model = agent traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup # ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device) obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
...@@ -280,6 +299,7 @@ def main(): ...@@ -280,6 +299,7 @@ def main():
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device) learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000) avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000) avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game # TRY NOT TO MODIFY: start the game
global_step = 0 global_step = 0
...@@ -296,7 +316,6 @@ def main(): ...@@ -296,7 +316,6 @@ def main():
]) ])
np.random.shuffle(ai_player1_) np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype) ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0 step = 0
for iteration in range(args.num_iterations): for iteration in range(args.num_iterations):
...@@ -320,6 +339,10 @@ def main(): ...@@ -320,6 +339,10 @@ def main():
_start = time.time() _start = time.time()
logits, value = predict_step(traced_model, next_obs) logits, value = predict_step(traced_model, next_obs)
if args.fix_target:
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten() value = value.flatten()
probs = Categorical(logits=logits) probs = Categorical(logits=logits)
action = probs.sample() action = probs.sample()
...@@ -331,10 +354,6 @@ def main(): ...@@ -331,10 +354,6 @@ def main():
action = action.cpu().numpy() action = action.cpu().numpy()
model_time += time.time() - _start model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time() _start = time.time()
to_play = next_to_play_ to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action) next_obs, reward, next_done_, info = envs.step(action)
...@@ -378,8 +397,12 @@ def main(): ...@@ -378,8 +397,12 @@ def main():
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
value = predict_step(traced_model, next_obs)[1].reshape(-1) value = predict_step(traced_model, next_obs)[1].reshape(-1)
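# Zero-sum bootstrap: when it is the opponent's turn to act, this player's value is taken as the negation of the opponent's prediction.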
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1) nextvalues1 = torch.where(next_to_play == ai_player1, value, -value)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2) if args.fix_target:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
nextvalues2 = torch.where(next_to_play != ai_player1, value_t, -value_t)
else:
nextvalues2 = -nextvalues1
if step > 0 and iteration != 0: if step > 0 and iteration != 0:
# recalculate the values for the first few steps # recalculate the values for the first few steps
...@@ -409,10 +432,10 @@ def main(): ...@@ -409,10 +432,10 @@ def main():
b_advantages = advantages[:args.num_steps].reshape(-1) b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1) b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values b_returns = b_advantages + b_values
if args.learn_opponent: if args.fix_target:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
b_learns = learns[:args.num_steps].reshape(-1) b_learns = learns[:args.num_steps].reshape(-1)
else:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
# Optimizing the policy and value network # Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size) b_inds = np.arange(args.local_batch_size)
...@@ -476,6 +499,28 @@ def main(): ...@@ -476,6 +499,28 @@ def main():
if rank == 0: if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step) writer.add_scalar("charts/SPS", SPS, global_step)
if args.fix_target:
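# Target-network promotion: rank 0 checks the win-rate/return thresholds and the decision is synchronized to all ranks via all_reduce.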
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
if args.eval_interval and iteration % args.eval_interval == 0: if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy # Eval with rule-based policy
_start = time.time() _start = time.time()
......
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Literal, Optional
import ygoenv
import numpy as np
import optree
import tyro
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, load_embeddings
from ygoai.rl.agent2 import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = None
"""the checkpoint to load the model from"""
total_timesteps: int = 2000000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train: bool = False
"""if toggled, training will be done in fp16 precision"""
fp16_eval: bool = False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir: str = "./runs"
"""tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 500
"""the number of iterations to save the model"""
log_p: float = 1.0
"""the probability of logging"""
eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_minibatches: int = 0
"""the number of mini-batches (computed in runtime)"""
def main():
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args)
args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps)
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
local_torch_threads = args.torch_threads // args.world_size
local_env_threads = args.env_threads // args.world_size
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1:
torchrun_setup(args.backend, local_rank)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None
if rank == 0:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(os.path.join(args.tb_dir, run_name))
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
ckpt_dir = os.path.join(args.ckpt_dir, run_name)
os.makedirs(ckpt_dir, exist_ok=True)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we need to pass a different seed to each data-parallel worker
args.seed += rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - rank)
if args.torch_deterministic:
torch.backends.cudnn.deterministic = True
else:
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu")
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
else:
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint:
agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
fprint(f"Loaded checkpoint from {args.checkpoint}")
elif args.embedding_file:
agent.load_embeddings(embeddings)
fprint(f"Loaded embeddings from {args.embedding_file}")
if args.embedding_file:
agent.freeze_embeddings()
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
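# Masked statistics: invalid (padded) positions are zeroed out and excluded from the mean and normalization below.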
def masked_mean(x, valid):
x = x.masked_fill(~valid, 0)
return x.sum() / valid.float().sum()
def masked_normalize(x, valid, eps=1e-8):
x = x.masked_fill(~valid, 0)
n = valid.float().sum()
mean = x.sum() / n
var = ((x - mean) ** 2).sum() / n
std = (var + eps).sqrt()
return (x - mean) / std
def train_step(agent: Agent, scaler, mb_obs, lstm_state, mb_dones, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid, _ = agent(mb_obs, lstm_state, mb_dones)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
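# approx_kl below is E[(ratio - 1) - log(ratio)], an unbiased, low-variance estimator of KL(pi_old || pi_new)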
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
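# The caller all-reduces the (unscaled) gradients across ranks, clips them, and calls scaler.step(optimizer).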
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def predict_step(agent: Agent, next_obs, next_lstm_state, next_done):
with torch.no_grad():
with autocast(enabled=args.fp16_eval):
logits, value, valid, next_lstm_state = agent(next_obs, next_lstm_state, next_done)
return logits, value, next_lstm_state
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs, next_lstm_state, next_done), check_tolerance=False, check_trace=False)
train_step = torch.compile(train_step, mode=args.compile)
def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.num_steps, args.local_num_envs), device)
actions = torch.zeros((args.num_steps, args.local_num_envs) + action_shape).to(device)
logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
learns = torch.zeros((args.num_steps, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
# TRY NOT TO MODIFY: start the game
global_step = 0
warmup_steps = 0
start_time = time.time()
next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, dtype=torch.uint8)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
next_lstm_state = (
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, args.local_num_envs, agent.lstm.hidden_size, device=device),
)
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, dtype=next_to_play.dtype)
next_value1 = 0
next_value2 = 0
for iteration in range(1, args.num_iterations + 1):
initial_lstm_state = (next_lstm_state[0].clone(), next_lstm_state[1].clone())
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
model_time = 0
env_time = 0
collect_start = time.time()
agent.eval()
for step in range(0, args.num_steps):
global_step += args.num_envs
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
learn = next_to_play == ai_player1
learns[step] = learn
_start = time.time()
logits, value, next_lstm_state = predict_step(traced_model, next_obs, next_lstm_state, next_done)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
logprob = probs.log_prob(action)
values[step] = value
actions[step] = action
logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_)
env_time += time.time() - _start
rewards[step] = to_tensor(reward)
next_obs, next_done = to_tensor(next_obs, torch.uint8), to_tensor(next_done_, torch.bool)
if not writer:
continue
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
if random.random() < args.log_p:
n = 100
if random.random() < 10/n or iteration <= 2:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
if random.random() < 1/n:
writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start
if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}")
_start = time.time()
# bootstrap value if not done
with torch.no_grad():
value = traced_model(next_obs, next_lstm_state, next_done)[1].reshape(-1)
advantages = torch.zeros_like(rewards).to(device)
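# Bootstrap per player: use the fresh prediction when it is that player's turn, otherwise the value cached the last time that player acted.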
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(args.num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# nextvalues1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# delta1 = reward1 + args.gamma * nextvalues1 - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# advantages[t] = lastgaelam1_
# nextvalues1 = values[t]
# lastgaelam1 = lastgaelam1_
# else:
# if dones[t+1]:
# reward2 = rewards[t]
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
#
# reward1 = -rewards[t]
# done_used1 = False
# else:
# if not done_used2:
# reward2 = reward2
# nextvalues2 = 0
# lastgaelam2 = 0
# done_used2 = True
# else:
# reward2 = rewards[t]
# reward1 = reward1
# delta2 = reward2 + args.gamma * nextvalues2 - values[t]
# lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
# advantages[t] = lastgaelam2_
# nextvalues2 = values[t]
# lastgaelam2 = lastgaelam2_
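# Vectorized equivalent of the pseudocode above: keep separate bootstrap values and GAE accumulators for the two players, selected per step by the `learns` mask; terminal rewards are zero-sum (opposite signs for the two players).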
learn1 = learns[t]
learn2 = ~learn1
if t != args.num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + args.gamma * nextvalues1 - values[t]
delta2 = reward2 + args.gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + args.gamma * args.gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
returns = advantages + values
bootstrap_time = time.time() - _start
_start = time.time()
agent.train()
# flatten the batch
b_obs = {
k: v.reshape((-1,) + v.shape[2:])
for k, v in obs.items()
}
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + action_shape)
b_dones = dones.reshape(-1)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_learns = learns.reshape(-1)
# Optimizing the policy and value network
assert args.local_num_envs % args.num_minibatches == 0
envsperbatch = args.local_num_envs // args.num_minibatches  # == local_minibatch_size // num_steps
envinds = np.arange(args.local_num_envs)
flatinds = np.arange(args.local_batch_size).reshape(args.num_steps, args.local_num_envs)
clipfracs = []
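# Minibatches are formed from whole environments (not flattened steps) so each sequence can be replayed through the LSTM from its stored initial hidden state.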
for epoch in range(args.update_epochs):
np.random.shuffle(envinds)
for start in range(0, args.local_num_envs, envsperbatch):
end = start + envsperbatch
mbenvinds = envinds[start:end]
mb_inds = flatinds[:, mbenvinds].ravel() # be really careful about the index
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = train_step(
agent, scaler, mb_obs, (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
b_dones[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds])
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
train_time = time.time() - _start
if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warm up for the first few iterations to get an accurate SPS measurement
SPS_warmup_iters = 10
if iteration == SPS_warmup_iters:
start_time = time.time()
warmup_steps = global_step
if iteration > SPS_warmup_iters:
if local_rank == 0:
fprint(f"SPS: {SPS}")
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
episode_lengths = []
episode_rewards = []
eval_win_rates = []
e_obs = eval_envs.reset()[0]
e_dones_ = np.zeros(local_eval_num_envs, dtype=np.bool_)
e_next_lstm_state = (
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
torch.zeros(agent.lstm.num_layers, local_eval_num_envs, agent.lstm.hidden_size, device=device),
)
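# Greedy evaluation against the rule-based opponent: take the argmax action instead of sampling.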
while True:
e_obs = to_tensor(e_obs, dtype=torch.uint8)
e_dones = to_tensor(e_dones_, dtype=torch.bool)
e_logits, _, e_next_lstm_state = predict_step(traced_model, e_obs, e_next_lstm_state, e_dones)
e_probs = torch.softmax(e_logits, dim=-1)
e_probs = e_probs.cpu().numpy()
e_actions = e_probs.argmax(axis=1)
e_obs, e_rewards, e_dones_, e_info = eval_envs.step(e_actions)
for idx, d in enumerate(e_dones_):
if d:
episode_length = e_info['l'][idx]
episode_reward = e_info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= local_eval_episodes:
break
eval_return = np.mean(episode_rewards[:local_eval_episodes])
eval_ep_len = np.mean(episode_lengths[:local_eval_episodes])
eval_win_rate = np.mean(eval_win_rates[:local_eval_episodes])
eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
if rank == 0:
writer.add_scalar("charts/eval_return", eval_return, global_step)
writer.add_scalar("charts/eval_ep_len", eval_ep_len, global_step)
writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
if local_rank == 0:
eval_time = time.time() - _start
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}, eval_ep_len={eval_ep_len:.1f}, eval_win_rate={eval_win_rate:.4f}")
# Eval with old model
if args.world_size > 1:
dist.destroy_process_group()
envs.close()
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_final.pt"))
writer.close()
if __name__ == "__main__":
main()
...@@ -69,11 +69,11 @@ class Args: ...@@ -69,11 +69,11 @@ class Args:
"""the number of parallel game environments""" """the number of parallel game environments"""
num_steps: int = 128 num_steps: int = 128
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = False anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0 gamma: float = 1.0
"""the discount factor gamma""" """the discount factor gamma"""
gae_lambda: float = 0.98 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
update_win_rate: float = 0.55 update_win_rate: float = 0.55
...@@ -103,8 +103,6 @@ class Args: ...@@ -103,8 +103,6 @@ class Args:
"""coefficient of the value function""" """coefficient of the value function"""
max_grad_norm: float = 1.0 max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)""" """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
...@@ -145,6 +143,8 @@ class Args: ...@@ -145,6 +143,8 @@ class Args:
"""the number of iterations (computed in runtime)""" """the number of iterations (computed in runtime)"""
world_size: int = 0 world_size: int = 0
"""the number of processes (computed in runtime)""" """the number of processes (computed in runtime)"""
num_embeddings: Optional[int] = None
"""the number of embeddings (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'): def make_env(args, num_envs, num_threads, mode='self'):
...@@ -158,7 +158,7 @@ def make_env(args, num_envs, num_threads, mode='self'): ...@@ -158,7 +158,7 @@ def make_env(args, num_envs, num_threads, mode='self'):
deck2=args.deck2, deck2=args.deck2,
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
play_mode='self', play_mode=mode,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
...@@ -181,6 +181,7 @@ def main(): ...@@ -181,6 +181,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size) args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps) args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size) args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps args.collect_length = args.collect_length or args.num_steps
...@@ -473,7 +474,7 @@ def main(): ...@@ -473,7 +474,7 @@ def main():
b_advantages = advantages[:args.num_steps].reshape(-1) b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1) b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values b_returns = b_advantages + b_values
if args.learn_opponent or selfplay: if selfplay:
b_learns = torch.ones_like(b_values, dtype=torch.bool) b_learns = torch.ones_like(b_values, dtype=torch.bool)
else: else:
b_learns = learns[:args.num_steps].reshape(-1) b_learns = learns[:args.num_steps].reshape(-1)
......
...@@ -3,7 +3,7 @@ import random ...@@ -3,7 +3,7 @@ import random
import time import time
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, Optional from typing import Optional
import ygoenv import ygoenv
...@@ -11,18 +11,21 @@ import numpy as np ...@@ -11,18 +11,21 @@ import numpy as np
import tyro import tyro
import torch import torch
import torch.nn as nn torch.set_num_threads(2)
import torch.optim as optim import torch.optim as optim
from torch.distributions import Categorical
import torch.distributed as dist import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint from ygoai.rl.dist import fprint
from ygoai.rl.buffer import create_obs from ygoai.rl.buffer import create_obs, get_obs_shape
from ygoai.rl.ppo import bootstrap_value_self from ygoai.rl.ppo import bootstrap_value_selfplay_np as bootstrap_value_selfplay
from ygoai.rl.eval import evaluate from ygoai.rl.eval import evaluate
...@@ -52,10 +55,8 @@ class Args: ...@@ -52,10 +55,8 @@ class Args:
"""the embedding file for card embeddings""" """the embedding file for card embeddings"""
max_options: int = 24 max_options: int = 24
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 16 n_history_actions: int = 32
"""the number of history actions to use""" """the number of history actions to use"""
play_mode: str = "bot"
"""the play mode, can be combination of 'bot' (greedy), 'random', like 'bot+random'"""
num_layers: int = 2 num_layers: int = 2
"""the number of layers for the agent""" """the number of layers for the agent"""
...@@ -68,29 +69,33 @@ class Args: ...@@ -68,29 +69,33 @@ class Args:
"""total timesteps of the experiments""" """total timesteps of the experiments"""
learning_rate: float = 2.5e-4 learning_rate: float = 2.5e-4
"""the learning rate of the optimizer""" """the learning rate of the optimizer"""
num_envs: int = 8 local_num_envs: int = 256
"""the number of parallel game environments""" "the number of parallel game environments"
local_env_threads: Optional[int] = None
"the number of threads to use for environment"
num_steps: int = 128 num_steps: int = 128
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.997 gamma: float = 1.0
"""the discount factor gamma""" """the discount factor gamma"""
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""the lambda for the general advantage estimation""" """the lambda for the general advantage estimation"""
fix_target: bool = False
"""if toggled, the target network will be fixed"""
update_win_rate: float = 0.55 update_win_rate: float = 0.55
"""the required win rate to update the agent""" """the required win rate to update the agent"""
update_return: float = 0.1 update_return: float = 0.1
"""the required return to update the agent""" """the required return to update the agent"""
minibatch_size: int = 256 local_minibatch_size: int = 4096
"""the mini-batch size""" """the mini-batch size"""
update_epochs: int = 2 update_epochs: int = 2
"""the K epochs to update the policy""" """the K epochs to update the policy"""
norm_adv: bool = True norm_adv: bool = True
"""Toggles advantages normalization""" """Toggles advantages normalization"""
clip_coef: float = 0.1 clip_coef: float = 0.2
"""the surrogate clipping coefficient""" """the surrogate clipping coefficient"""
clip_vloss: bool = True clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper.""" """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
...@@ -98,17 +103,11 @@ class Args: ...@@ -98,17 +103,11 @@ class Args:
"""coefficient of the entropy""" """coefficient of the entropy"""
vf_coef: float = 0.5 vf_coef: float = 0.5
"""coefficient of the value function""" """coefficient of the value function"""
max_grad_norm: float = 0.5 max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping""" """the maximum norm for the gradient clipping"""
target_kl: Optional[float] = None
"""the target KL divergence threshold"""
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)""" """the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
backend: Literal["gloo", "nccl", "mpi"] = "nccl"
"""the backend for distributed training"""
compile: Optional[str] = None compile: Optional[str] = None
"""Compile mode of torch.compile, None for no compilation""" """Compile mode of torch.compile, None for no compilation"""
torch_threads: Optional[int] = None torch_threads: Optional[int] = None
...@@ -130,51 +129,61 @@ class Args: ...@@ -130,51 +129,61 @@ class Args:
"""the probability of logging""" """the probability of logging"""
eval_episodes: int = 128 eval_episodes: int = 128
"""the number of episodes to evaluate the model""" """the number of episodes to evaluate the model"""
eval_interval: int = 10 eval_interval: int = 50
"""the number of iterations to evaluate the model""" """the number of iterations to evaluate the model"""
# to be filled in runtime # to be filled in runtime
local_batch_size: int = 0 local_batch_size: int = 0
"""the local batch size in the local rank (computed in runtime)""" minibatch_size: int = 0
local_minibatch_size: int = 0 num_envs: int = 0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs: int = 0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size: int = 0 batch_size: int = 0
"""the batch size (computed in runtime)"""
num_iterations: int = 0 num_iterations: int = 0
"""the number of iterations (computed in runtime)"""
world_size: int = 0 world_size: int = 0
"""the number of processes (computed in runtime)""" num_embeddings: Optional[int] = None
def make_env(args, num_envs, num_threads, mode='self'):
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
return envs
def main(): def _mp_fn(index, world_size):
rank = int(os.environ.get("RANK", 0)) rank = index
local_rank = int(os.environ.get("LOCAL_RANK", 0)) local_rank = index
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}") print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
args = tyro.cli(Args) args = tyro.cli(Args)
args.world_size = world_size args.world_size = world_size
args.local_num_envs = args.num_envs // args.world_size args.num_envs = args.local_num_envs * args.world_size
args.local_batch_size = int(args.local_num_envs * args.num_steps) args.local_batch_size = args.local_num_envs * args.num_steps
args.local_minibatch_size = int(args.minibatch_size // args.world_size) args.minibatch_size = args.local_minibatch_size * args.world_size
args.batch_size = int(args.num_envs * args.num_steps) args.batch_size = args.num_envs * args.num_steps
args.num_iterations = args.total_timesteps // args.batch_size args.num_iterations = args.total_timesteps // args.batch_size
args.env_threads = args.env_threads or args.num_envs args.local_env_threads = args.local_env_threads or args.local_num_envs
args.env_threads = args.local_env_threads * args.world_size
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size) args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps args.collect_length = args.collect_length or args.num_steps
assert args.local_batch_size % args.local_minibatch_size == 0, "local_batch_size must be divisible by local_minibatch_size"
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps" assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_torch_threads = args.torch_threads // args.world_size torch.set_num_threads(2)
local_env_threads = args.env_threads // args.world_size # torch.set_float32_matmul_precision('high')
torch.set_num_threads(local_torch_threads)
torch.set_float32_matmul_precision('high')
if args.world_size > 1: if args.world_size > 1:
torchrun_setup(args.backend, local_rank) dist.init_process_group('xla', init_method='xla://')
timestamp = int(time.time()) timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
...@@ -197,55 +206,29 @@ def main(): ...@@ -197,55 +206,29 @@ def main():
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed - rank) torch.manual_seed(args.seed - rank)
if args.torch_deterministic: # if args.torch_deterministic:
torch.backends.cudnn.deterministic = True # torch.backends.cudnn.deterministic = True
else: # else:
torch.backends.cudnn.benchmark = True # torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() and args.cuda else "cpu") device = xm.xla_device()
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file) deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck args.deck2 = args.deck2 or deck
# env setup # env setup
envs = ygoenv.make( envs = make_env(args, args.local_num_envs, args.local_env_threads)
task_id=args.env_id, obs_space = envs.env.observation_space
env_type="gymnasium", action_shape = envs.env.action_space.shape
num_envs=args.local_num_envs,
num_threads=local_env_threads,
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
)
envs.num_envs = args.local_num_envs
obs_space = envs.observation_space
action_shape = envs.action_space.shape
if local_rank == 0: if local_rank == 0:
fprint(f"obs_space={obs_space}, action_shape={action_shape}") fprint(f"obs_space={obs_space}, action_shape={action_shape}")
envs_per_thread = args.local_num_envs // local_env_threads envs_per_thread = args.local_num_envs // args.local_env_threads
local_eval_episodes = args.eval_episodes // args.world_size local_eval_episodes = args.eval_episodes // args.world_size
local_eval_num_envs = local_eval_episodes local_eval_num_envs = local_eval_episodes
eval_envs = ygoenv.make( local_eval_num_threads = max(1, local_eval_num_envs // envs_per_thread)
task_id=args.env_id, eval_envs = make_env(args, local_eval_num_envs, local_eval_num_threads, mode='bot')
env_type="gymnasium",
num_envs=local_eval_num_envs,
num_threads=max(1, local_eval_num_envs // envs_per_thread),
seed=args.seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode=args.play_mode,
)
eval_envs.num_envs = local_eval_num_envs
envs = RecordEpisodeStatistics(envs)
eval_envs = RecordEpisodeStatistics(eval_envs)
if args.embedding_file: if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file) embeddings = load_embeddings(args.embedding_file, args.code_list_file)
...@@ -265,44 +248,48 @@ def main(): ...@@ -265,44 +248,48 @@ def main():
if args.embedding_file: if args.embedding_file:
agent.freeze_embeddings() agent.freeze_embeddings()
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device) if args.fix_target:
agent_t.eval() agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.load_state_dict(agent.state_dict()) agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
else:
agent_t = agent
# if args.world_size > 1:
# ddp_agent = DDP(agent, gradient_as_bucket_view=True)
# else:
# ddp_agent = agent
optim_params = list(agent.parameters()) optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5) optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
def predict_step(agent: Agent, next_obs): def predict_step(agent: Agent, next_obs):
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=args.fp16_eval): logits, value, valid = agent(next_obs)
logits, value, valid = agent(next_obs)
return logits, value return logits, value
from ygoai.rl.ppo import train_step from ygoai.rl.ppo import train_step_t as train_step
if args.compile: if args.compile:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here traced_model = torch.compile(agent, backend='openxla_eval')
# predict_step = torch.compile(predict_step, mode=args.compile) traced_model_t = traced_model
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device) train_step = torch.compile(train_step, backend='openxla')
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile)
else: else:
traced_model = agent traced_model = agent
traced_model_t = agent_t traced_model_t = agent_t
# ALGO Logic: Storage setup # ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device) obs_shape = get_obs_shape(obs_space)
actions = torch.zeros((args.collect_length, args.local_num_envs) + action_shape).to(device) obs = {
logprobs = torch.zeros((args.collect_length, args.local_num_envs)).to(device) key: np.zeros(
rewards = torch.zeros((args.collect_length, args.local_num_envs)).to(device) (args.collect_length, args.local_num_envs, *_obs_shape), dtype=obs_space[key].dtype)
dones = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device) for key, _obs_shape in obs_shape.items()
values = torch.zeros((args.collect_length, args.local_num_envs)).to(device) }
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device) actions = np.zeros((args.collect_length, args.local_num_envs) + action_shape, dtype=np.int64)
logprobs = np.zeros((args.collect_length, args.local_num_envs), dtype=np.float32)
rewards = np.zeros((args.collect_length, args.local_num_envs), dtype=np.float32)
dones = np.zeros((args.collect_length, args.local_num_envs), dtype=np.bool_)
values = np.zeros((args.collect_length, args.local_num_envs), dtype=np.float32)
learns = np.zeros((args.collect_length, args.local_num_envs), dtype=np.bool_)
avg_ep_returns = deque(maxlen=1000) avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000) avg_win_rates = deque(maxlen=1000)
version = 0 version = 0
...@@ -312,73 +299,89 @@ def main(): ...@@ -312,73 +299,89 @@ def main():
warmup_steps = 0 warmup_steps = 0
start_time = time.time() start_time = time.time()
next_obs, info = envs.reset() next_obs, info = envs.reset()
next_obs = to_tensor(next_obs, device, dtype=torch.uint8) next_obs_ = to_tensor(next_obs, device, dtype=torch.uint8)
next_to_play_ = info["to_play"] next_to_play = info["to_play"]
next_to_play = to_tensor(next_to_play_, device) next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool) ai_player1 = np.concatenate([
ai_player1_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64), np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
]) ])
np.random.shuffle(ai_player1_) np.random.shuffle(ai_player1)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype) next_value1 = next_value2 = 0
next_value = 0
step = 0 step = 0
for iteration in range(1, args.num_iterations + 1): for iteration in range(args.num_iterations):
# Annealing the rate if instructed to do so. # Annealing the rate if instructed to do so.
if args.anneal_lr: if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations frac = 1.0 - iteration / args.num_iterations
lrnow = frac * args.learning_rate lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow optimizer.param_groups[0]["lr"] = lrnow
model_time = 0 model_time = 0
env_time = 0 env_time = 0
o_time1 = 0
o_time2 = 0
collect_start = time.time() collect_start = time.time()
while step < args.collect_length: while step < args.collect_length:
global_step += args.num_envs global_step += args.num_envs
_start = time.time()
for key in obs: for key in obs:
obs[key][step] = next_obs[key] obs[key][step] = next_obs[key]
dones[step] = next_done dones[step] = next_done
learn = next_to_play == ai_player1 learn = next_to_play == ai_player1
learns[step] = learn learns[step] = learn
o_time1 += time.time() - _start
_start = time.time() _start = time.time()
logits, value = predict_step(traced_model, next_obs) logits, value = predict_step(traced_model, next_obs_)
logits_t, value_t = predict_step(traced_model_t, next_obs) if args.fix_target:
logits = torch.where(learn[:, None], logits, logits_t) logits_t, value_t = predict_step(traced_model_t, next_obs)
value = torch.where(learn[:, None], value, value_t) logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
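# sample actions with the Gumbel-max trick: argmax(logits + Gumbel noise)
# is equivalent to sampling from the softmax over logits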
u = torch.rand_like(logits)
action = torch.argmax(logits - torch.log(-torch.log(u)), dim=1)
logprob = logits.log_softmax(dim=1).gather(-1, action[:, None]).squeeze(-1)
value = value.flatten() value = value.flatten()
probs = Categorical(logits=logits) xm.mark_step()
action = probs.sample() model_time += time.time() - _start
logprob = probs.log_prob(action)
_start = time.time()
logprob = logprob.cpu().numpy()
value = value.cpu().numpy()
action = action.cpu().numpy()
o_time2 += time.time() - _start
_start = time.time()
values[step] = value values[step] = value
actions[step] = action actions[step] = action
logprobs[step] = logprob logprobs[step] = logprob
action = action.cpu().numpy()
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float() next_nonterminal = 1 - next_done.astype(np.float32)
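# per-player bootstrap values: next_value1 holds the last value predicted on
# player-1 turns (learn), next_value2 on player-2 turns; both are zeroed at episode ends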
next_value = torch.where(learn, value, next_value) * next_nonterminal next_value1 = np.where(learn, value, next_value1) * next_nonterminal
next_value2 = np.where(learn, next_value2, value) * next_nonterminal
o_time1 += time.time() - _start
_start = time.time() _start = time.time()
to_play = next_to_play_ to_play = next_to_play
next_obs, reward, next_done_, info = envs.step(action) next_obs, reward, next_done, info = envs.step(action)
next_to_play_ = info["to_play"] next_to_play = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
env_time += time.time() - _start env_time += time.time() - _start
rewards[step] = to_tensor(reward, device) _start = time.time()
next_obs, next_done = to_tensor(next_obs, device, torch.uint8), to_tensor(next_done_, device, torch.bool) rewards[step] = reward
o_time1 += time.time() - _start
_start = time.time()
next_obs_ = to_tensor(next_obs, device, torch.uint8)
o_time2 += time.time() - _start
step += 1 step += 1
if not writer: if not writer:
continue continue
for idx, d in enumerate(next_done_): for idx, d in enumerate(next_done):
if d: if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1 pl = 1 if to_play[idx] == ai_player1[idx] else -1
episode_length = info['l'][idx] episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
...@@ -387,7 +390,7 @@ def main(): ...@@ -387,7 +390,7 @@ def main():
if random.random() < args.log_p: if random.random() < args.log_p:
n = 100 n = 100
if random.random() < 10/n or iteration <= 2: if random.random() < 10/n or iteration <= 1:
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}") fprint(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
...@@ -397,20 +400,23 @@ def main(): ...@@ -397,20 +400,23 @@ def main():
writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step) writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
collect_time = time.time() - collect_start collect_time = time.time() - collect_start
if local_rank == 0: # if local_rank == 0:
fprint(f"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}") fprint(f"[Rank {rank}] collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}, o_time1={o_time1:.4f}, o_time2={o_time2:.4f}")
step = args.collect_length - args.num_steps step = args.collect_length - args.num_steps
_start = time.time() _start = time.time()
# bootstrap value if not done # bootstrap value if not done
with torch.no_grad(): with torch.no_grad():
value = traced_model(next_obs)[1].reshape(-1) value = predict_step(traced_model, next_obs_)[1].reshape(-1)
value_t = traced_model_t(next_obs)[1].reshape(-1) if args.fix_target:
value = torch.where(next_to_play == ai_player1, value, value_t) value_t = predict_step(traced_model_t, next_obs_)[1].reshape(-1)
nextvalues = torch.where(next_to_play == ai_player1, value, next_value) value = torch.where(next_to_play == ai_player1, value, value_t)
value = value.cpu().numpy()
if step > 0 and iteration != 1: nextvalues1 = np.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = np.where(next_to_play != ai_player1, value, next_value2)
if step > 0 and iteration != 0:
# recalculate the values for the first few steps # recalculate the values for the first few steps
v_steps = args.local_minibatch_size * 4 // args.local_num_envs v_steps = args.local_minibatch_size * 4 // args.local_num_envs
for v_start in range(0, step, v_steps): for v_start in range(0, step, v_steps):
...@@ -423,11 +429,14 @@ def main(): ...@@ -423,11 +429,14 @@ def main():
value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1) value = predict_step(traced_model, v_obs)[1].reshape(v_end - v_start, -1)
values[v_start:v_end] = value values[v_start:v_end] = value
advantages = bootstrap_value_self( advantages = bootstrap_value_selfplay(
values, rewards, dones, learns, nextvalues, next_done, args.gamma, args.gae_lambda) values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, args.gamma, args.gae_lambda)
bootstrap_time = time.time() - _start bootstrap_time = time.time() - _start
_start = time.time() train_start = time.time()
d_time1 = 0
d_time2 = 0
d_time3 = 0
# flatten the batch # flatten the batch
b_obs = { b_obs = {
k: v[:args.num_steps].reshape((-1,) + v.shape[2:]) k: v[:args.num_steps].reshape((-1,) + v.shape[2:])
...@@ -437,32 +446,71 @@ def main(): ...@@ -437,32 +446,71 @@ def main():
b_logprobs = logprobs[:args.num_steps].reshape(-1) b_logprobs = logprobs[:args.num_steps].reshape(-1)
b_advantages = advantages[:args.num_steps].reshape(-1) b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1) b_values = values[:args.num_steps].reshape(-1)
b_learns = torch.ones_like(b_values, dtype=torch.bool) if args.learn_opponent else learns[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values b_returns = b_advantages + b_values
if args.fix_target:
b_learns = learns[:args.num_steps].reshape(-1)
else:
b_learns = np.ones_like(b_values, dtype=np.bool_)
_start = time.time()
b_obs = to_tensor(b_obs, device=device, dtype=torch.uint8)
b_actions, b_logprobs, b_advantages, b_values, b_returns, b_learns = [
to_tensor(v, device) for v in [b_actions, b_logprobs, b_advantages, b_values, b_returns, b_learns]
]
d_time1 += time.time() - _start
agent.train()
model_time = 0
# Optimizing the policy and value network # Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
clipfracs = [] clipfracs = []
b_inds = np.arange(args.local_batch_size)
xm.mark_step()
for epoch in range(args.update_epochs): for epoch in range(args.update_epochs):
_start = time.time()
np.random.shuffle(b_inds) np.random.shuffle(b_inds)
for start in range(0, args.local_batch_size, args.local_minibatch_size): d_time2 += time.time() - _start
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, scaler, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds], b_learns[mb_inds], args)
reduce_gradidents(optim_params, args.world_size)
nn.utils.clip_grad_norm_(optim_params, args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
clipfracs.append(clipfrac.item())
if args.target_kl is not None and approx_kl > args.target_kl:
break
_start = time.time()
b_inds_ = to_tensor(b_inds, device=device)
n_mini_batches = args.local_batch_size // args.local_minibatch_size
b_inds_ = b_inds_.reshape(n_mini_batches, args.local_minibatch_size)
xm.mark_step()
d_time3 += time.time() - _start
for i in range(n_mini_batches):
_start = time.time()
mb_inds = b_inds_[i]
xm.mark_step()
d_time3 += time.time() - _start
_start = time.time()
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages,
b_returns, b_values, b_learns, mb_inds, args)
clipfracs.append(clipfrac)
xm.mark_step()
model_time += time.time() - _start
# mb_obs = {
# k: v[mb_inds] for k, v in b_obs.items()
# }
# mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns = [
# v[mb_inds] for v in [b_actions, b_logprobs, b_advantages, b_returns, b_values, b_learns]]
# xm.mark_step()
# old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
# train_step(ddp_agent_t, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages,
# mb_returns, mb_values, mb_learns, args)
# if rank == 0:
# # For short report that only contains a few key metrics.
# print(met.short_metrics_report())
# # For full report that includes all metrics.
# print(met.metrics_report())
# met.clear_all()
clipfrac = torch.stack(clipfracs).mean().item()
if step > 0: if step > 0:
# TODO: use cyclic buffer to avoid copying # TODO: use cyclic buffer to avoid copying
for v in obs.values(): for v in obs.values():
...@@ -470,16 +518,16 @@ def main(): ...@@ -470,16 +518,16 @@ def main():
for v in [actions, logprobs, rewards, dones, values, learns]: for v in [actions, logprobs, rewards, dones, values, learns]:
v[:step] = v[args.num_steps:].clone() v[:step] = v[args.num_steps:].clone()
train_time = time.time() - _start train_time = time.time() - train_start
if local_rank == 0: if local_rank == 0:
fprint(f"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}") fprint(f"d_time1={d_time1:.4f}, d_time2={d_time2:.4f}, d_time3={d_time3:.4f}")
fprint(f"train_time={train_time:.4f}, model_time={model_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}")
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true) var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# TRY NOT TO MODIFY: record rewards for plotting purposes
if rank == 0: if rank == 0:
if iteration % args.save_interval == 0: if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt")) torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent.pt"))
...@@ -490,13 +538,13 @@ def main(): ...@@ -490,13 +538,13 @@ def main():
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) writer.add_scalar("losses/clipfrac", clipfrac, global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step) writer.add_scalar("losses/explained_variance", explained_var, global_step)
SPS = int((global_step - warmup_steps) / (time.time() - start_time)) SPS = int((global_step - warmup_steps) / (time.time() - start_time))
# Warmup at first few iterations for accurate SPS measurement # Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters = 10 SPS_warmup_iters = 5
if iteration == SPS_warmup_iters: if iteration == SPS_warmup_iters:
start_time = time.time() start_time = time.time()
warmup_steps = global_step warmup_steps = global_step
...@@ -506,41 +554,44 @@ def main(): ...@@ -506,41 +554,44 @@ def main():
if rank == 0: if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step) writer.add_scalar("charts/SPS", SPS, global_step)
if rank == 0: if args.fix_target:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0: if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt")) should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}") should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
avg_win_rates.clear() else:
avg_ep_returns.clear() should_update = torch.zeros((), dtype=torch.int64, device=device)
_start = time.time()
eval_return = evaluate(
eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# sync the statistics
if args.world_size > 1: if args.world_size > 1:
dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG) dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
eval_return = eval_stats.cpu().numpy() should_update = should_update.item() > 0
if rank == 0: if should_update:
writer.add_scalar("charts/eval_return", eval_return, global_step) agent_t.load_state_dict(agent.state_dict())
if local_rank == 0: with torch.no_grad():
eval_time = time.time() - _start traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}") traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
# if args.eval_interval and iteration % args.eval_interval == 0:
# # Eval with rule-based policy
# _start = time.time()
# eval_return = evaluate(
# eval_envs, traced_model, local_eval_episodes, device, args.fp16_eval)[0]
# eval_stats = torch.tensor(eval_return, dtype=torch.float32, device=device)
# # sync the statistics
# if args.world_size > 1:
# dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
# eval_return = eval_stats.cpu().numpy()
# if rank == 0:
# writer.add_scalar("charts/eval_return", eval_return, global_step)
# if local_rank == 0:
# eval_time = time.time() - _start
# fprint(f"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}")
# Eval with old model # Eval with old model
...@@ -553,4 +604,8 @@ def main(): ...@@ -553,4 +604,8 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() world_size = int(os.getenv("WORLD_SIZE", "1"))
if world_size == 1:
_mp_fn(0, 1)
else:
xmp.spawn(_mp_fn, args=(world_size,))
...@@ -95,6 +95,69 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad ...@@ -95,6 +95,69 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad
return -jnp.mean(clipped_objective * mask) return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(6, 7))
def compute_gae_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
):
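# GAE scanned backwards through time for two-player zero-sum self-play:
# at steps where `switch` is set, the bootstrap value becomes the negated
# prediction for the post-rollout state and the GAE accumulator is reset
# to 0; otherwise this is the usual delta + gamma*lambda recursion.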
def body_fn(carry, inp):
pred_values, next_values, lastgaelam = carry
next_done, curvalues, reward, switch = inp
nextnonterminal = 1.0 - next_done
next_values = jnp.where(switch, -pred_values, next_values)
lastgaelam = jnp.where(switch, 0, lastgaelam)
delta = reward + gamma * next_values * nextnonterminal - curvalues
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
return (pred_values, curvalues, lastgaelam), lastgaelam
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_value, lastgaelam
_, advantages = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
)
target_values = advantages + values
return advantages, target_values
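# Illustrative usage sketch (editor's note, not part of this commit); shapes and
# hyperparameters are assumptions:
#   values, rewards, dones, switch: (num_steps, num_envs)
#   next_value, next_done: (num_envs,)
#   adv, target = compute_gae_2p0s(next_value, next_done, values, rewards,
#                                  dones, switch, 0.997, 0.95)
# gamma and gae_lambda are static arguments, so changing them triggers recompilation.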
@partial(jax.jit, static_argnums=(6, 7))
def compute_gae_upgo_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
):
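# Same switching scheme as above, but the return follows UPGO: bootstrap with
# the running return only while Q >= V, otherwise fall back to the value
# estimate. Returns (upgo_return - values, gae_advantage + values).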
def body_fn(carry, inp):
pred_value, next_value, next_q, last_return, lastgaelam = carry
next_done, curvalues, reward, switch = inp
gamma_ = gamma * (1.0 - next_done)
next_value = jnp.where(switch, -pred_value, next_value)
next_q = jnp.where(switch, -pred_value, next_q)
last_return = jnp.where(switch, -pred_value, last_return)
lastgaelam = jnp.where(switch, 0, lastgaelam)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
delta = next_q - curvalues
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
carry = pred_value, next_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return)
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_value, next_value, next_value, lastgaelam
_, (advantages, returns) = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
)
return returns - values, advantages + values
def compute_gae_once(carry, inp, gamma, gae_lambda): def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
......
...@@ -320,3 +320,62 @@ class PPOAgent(nn.Module): ...@@ -320,3 +320,62 @@ class PPOAgent(nn.Module):
logits = actor(f_state, f_actions, mask) logits = actor(f_state, f_actions, mask)
value = critic(f_state) value = critic(f_state)
return logits, value, valid return logits, value, valid
class PPOLSTMAgent(nn.Module):
channels: int = 128
num_layers: int = 2
lstm_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
multi_step: bool = False
@nn.compact
def __call__(self, inputs):
if self.multi_step:
# (num_steps * batch_size, ...)
carry1, carry2, x, done, switch = inputs
batch_size = carry1[0].shape[0]
num_steps = done.shape[0] // batch_size
else:
carry, x = inputs
c = self.channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
f_actions, f_state, mask, valid = encoder(x)
lstm_layer = nn.OptimizedLSTMCell(
self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
if self.multi_step:
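# scan the LSTM over time: the running carry (carry1) is zeroed at episode
# boundaries (`done`) and reset to the fixed carry (carry2) at steps where
# the acting player switches (`switch`)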
def body_fn(cell, carry, x, done, switch):
carry, init_carry = carry
carry, y = cell(carry, x)
carry = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), carry)
carry = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_carry, carry)
return (carry, init_carry), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
f_state, done, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch))
carry, f_state = scan(lstm_layer, (carry1, carry2), f_state, done, switch)
f_state = f_state.reshape((-1, f_state.shape[-1]))
else:
carry, f_state = lstm_layer(carry, f_state)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return carry, logits, value, valid
import numpy as np import numpy as np
def evaluate(envs, act_fn, params): def evaluate(envs, act_fn, params, rnn_state=None):
num_episodes = envs.num_envs num_episodes = envs.num_envs
episode_lengths = [] episode_lengths = []
episode_rewards = [] episode_rewards = []
eval_win_rates = [] eval_win_rates = []
obs = envs.reset()[0] obs = envs.reset()[0]
collected = np.zeros((num_episodes,), dtype=np.bool_)
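# count each environment's first finished episode only once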
while True: while True:
actions = act_fn(params, obs) if rnn_state is None:
actions = act_fn(params, obs)
else:
rnn_state, actions = act_fn(params, (rnn_state, obs))
actions = np.array(actions) actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions) obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if not d: if not d or collected[idx]:
continue continue
collected[idx] = True
episode_length = info['l'][idx] episode_length = info['l'][idx]
episode_reward = info['r'][idx] episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
......
...@@ -16,7 +16,7 @@ def entropy_from_logits(logits): ...@@ -16,7 +16,7 @@ def entropy_from_logits(logits):
def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args): def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
with autocast(enabled=args.fp16_train): with autocast(enabled=args.fp16_train):
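# keep only the first three outputs (logits, value, valid); agents may return
# extras (e.g. an RNN carry)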
logits, newvalue, valid = agent(mb_obs) logits, newvalue, valid = agent(mb_obs)[:3]
logits = logits - logits.logsumexp(dim=-1, keepdim=True) logits = logits - logits.logsumexp(dim=-1, keepdim=True)
newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1) newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
entropy = entropy_from_logits(logits) entropy = entropy_from_logits(logits)
......