Commit 9a2a21e3 authored by sbl1996@126.com

Merge branch 'enhance_history'

parents d8cbf274 e79a43ef
*.pt
*.ptj
*.pkl
# Xmake cache
.xmake/
......
@@ -88,4 +88,7 @@
## History Actions
- 0,1: card id, uint16 -> 2 uint8
- others same as legal actions
- 2-12: same as legal actions
- 13: player, discrete, 0: me, 1: oppo
- 14: turn, discrete, trunc to 3
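For orientation, a minimal sketch of the card-id packing described above: a uint16 card id occupies feature slots 0 and 1 as two uint8 values. The helper names and byte order here are assumptions for illustration; the actual encoding is done inside the environment.

```python
# Hypothetical helpers mirroring the layout above; the real encoding is
# done inside the environment, so names and byte order are assumptions.
def split_card_id(card_id: int) -> tuple[int, int]:
    """Pack a uint16 card id into the two uint8 feature slots (0 and 1)."""
    return (card_id >> 8) & 0xFF, card_id & 0xFF

def join_card_id(hi: int, lo: int) -> int:
    """Recover the uint16 card id from the two uint8 slots."""
    return (hi << 8) | lo

assert join_card_id(*split_card_id(0x1234)) == 0x1234
```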
@@ -14,18 +14,12 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
@@ -41,7 +35,7 @@ class Args:
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
@@ -50,8 +44,6 @@
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False
"""whether to play the game"""
selfplay: bool = False
"""whether to use selfplay"""
record: bool = False
"""whether to record the game as YGOPro replays"""
@@ -67,27 +59,36 @@ class Args:
strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used"""
agent: bool = False
"""whether to use the agent"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load"""
checkpoint: Optional[str] = None
"""the checkpoint to load, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = True
"""if toggled, the model will be optimized"""
convert: bool = False
"""if toggled, the model will be converted to a jit model and the program will exit"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework: Optional[Literal["torch", "jax"]] = None
"""the framework of the agent, inferred from the checkpoint if not set"""
if __name__ == "__main__":
args = tyro.cli(Args)
@@ -102,7 +103,6 @@ if __name__ == "__main__":
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
@@ -113,15 +113,20 @@ if __name__ == "__main__":
random.seed(seed)
np.random.seed(seed)
if args.agent:
if args.checkpoint and args.framework is None:
args.framework = "jax" if "flax_model" in args.checkpoint else "torch"
if args.framework == "torch":
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
else:
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
@@ -136,15 +141,22 @@ if __name__ == "__main__":
player=args.player,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
play_mode='human' if args.play else ('bot' if args.bot_type == "greedy" else "random"),
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
if args.agent:
if args.checkpoint and args.checkpoint.endswith(".ptj"):
if args.framework == 'torch':
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
if args.checkpoint.endswith(".ptj"):
agent = torch.jit.load(args.checkpoint)
else:
# count lines of code_list
@@ -154,13 +166,12 @@ if __name__ == "__main__":
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, 2, embedding_shape).to(device)
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
if args.convert:
@@ -191,6 +202,48 @@ if __name__ == "__main__":
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
def predict_fn(obs):
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits = agent(obs)[0]
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
return probs
else:
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOAgent
def create_agent(args):
return PPOAgent(
channels=128,
num_layers=2,
embedding_shape=args.num_embeddings,
)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(obs):
probs = get_probs(params, obs)
return np.array(probs)
print(f"loaded checkpoint from {args.checkpoint}")
obs, infos = envs.reset()
next_to_play = infos['to_play']
@@ -210,16 +263,11 @@ if __name__ == "__main__":
start_step = step
model_time = env_time = 0
if args.agent:
if args.framework:
_start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
probs = predict_fn(obs)
if args.verbose:
print([f"{p:.4f}" for p in probs[probs != 0].tolist()])
print(f"{values[0].item():.4f}")
actions = probs.argmax(axis=1)
model_time += time.time() - _start
else:
@@ -228,13 +276,6 @@ if __name__ == "__main__":
else:
actions = np.zeros(num_envs, dtype=np.int32)
# for k, v in obs.items():
# v = v[0]
# if k == 'cards_':
# v = np.concatenate([np.arange(v.shape[0])[:, None], v], axis=1)
# print(k, v.tolist())
# print(infos)
# print(actions[0])
to_play = next_to_play
_start = time.time()
@@ -249,15 +290,7 @@ if __name__ == "__main__":
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
if args.selfplay:
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
if episode_reward < 0:
win = 0
else:
win = 1
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
......
@@ -44,11 +44,9 @@ class PositionalEncoding(nn.Module):
class Encoder(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False, affine=True):
super(Encoder, self).__init__()
self.channels = channels
self.num_history_action_layers = num_history_action_layers
c = channels
self.loc_embed = nn.Embedding(9, c)
@@ -165,11 +163,17 @@ class Encoder(nn.Module):
for i in range(num_action_layers)
])
self.action_history_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_pe = PositionalEncoding(c, dropout=0.0)
self.history_action_net = nn.ModuleList([
nn.TransformerEncoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
for i in range(num_action_layers)
])
self.action_history_net = nn.ModuleList([
nn.TransformerDecoderLayer(
c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
for i in range(num_history_action_layers)
for i in range(num_action_layers)
])
self.action_norm = nn.LayerNorm(c, elementwise_affine=False)
@@ -287,6 +291,7 @@ class Encoder(nn.Module):
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
batch_size = x_cards.shape[0]
x_cards_1 = x_cards[:, :, :12].long()
x_cards_2 = x_cards[:, :, 12:].to(torch.float32)
@@ -294,7 +299,10 @@ class Encoder(nn.Module):
x_id = self.encode_card_id(x_cards_1[:, :, :2])
x_id = self.id_norm(x_id)
f_loc = self.loc_norm(self.loc_embed(x_cards_1[:, :, 2]))
x_loc = x_cards_1[:, :, 2]
c_mask = x_loc == 0
c_mask[:, 0] = False
f_loc = self.loc_norm(self.loc_embed(x_loc))
f_seq = self.seq_norm(self.seq_embed(x_cards_1[:, :, 3]))
x_feat1 = self.encode_card_feat1(x_cards_1)
@@ -306,11 +314,14 @@ class Encoder(nn.Module):
f_cards = torch.cat([x_id, x_feat], dim=-1)
f_cards = f_cards + f_loc + f_seq
f_na_card = self.na_card_embed.expand(f_cards.shape[0], -1, -1)
for layer in self.card_net:
# f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_cards = layer(f_cards, src_key_padding_mask=c_mask)
f_na_card = self.na_card_embed.expand(batch_size, -1, -1)
f_cards = torch.cat([f_na_card, f_cards], dim=1)
# TODO: we can't use this because CUDA graph capture rejects it (complex memory)
# c_mask = torch.cat([torch.zeros(batch_size, 1, dtype=c_mask.dtype, device=c_mask.device), c_mask], dim=1)
for layer in self.card_net:
f_cards = layer(f_cards)
f_cards = self.card_norm(f_cards)
x_global = self.encode_global(x_global)
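The TODO above notes that concatenating a freshly allocated zeros column onto `c_mask` breaks CUDA graph capture. An untested sketch of one alternative is to pad a constant `False` column instead of using `torch.cat`; whether this satisfies graph capture is an assumption:

```python
import torch
import torch.nn.functional as F

# Untested sketch: pad a constant False column in front of c_mask so the
# prepended NA-card token is never masked, avoiding the explicit
# torch.zeros + torch.cat that CUDA graph capture rejected.
c_mask = torch.zeros(4, 80, dtype=torch.bool)       # stand-in for the real mask
c_mask_padded = F.pad(c_mask, (1, 0), value=False)  # (B, 1 + num_cards)
assert c_mask_padded.shape == (4, 81)
```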
@@ -332,23 +343,27 @@ class Encoder(nn.Module):
mask = x_actions[:, :, 2] == 0 # msg == 0
valid = x['global_'][:, -1] == 0
mask[:, 0] &= valid
mask[:, 0] = False
# mask[:, 0] &= valid
for layer in self.action_card_net:
f_actions = layer(f_actions, f_cards, tgt_key_padding_mask=mask)
if self.num_history_action_layers != 0:
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.action_history_pe(f_h_actions)
for layer in self.action_history_net:
f_actions = layer(f_actions, f_h_actions)
f_actions = layer(
f_actions, f_cards[:, 1:], tgt_key_padding_mask=mask, memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_']
x_h_actions = x_h_actions.long()
x_h_id = self.get_h_action_card_(x_h_actions[..., :2])
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask[:, 0] = False
x_h_a_feats = self.encode_action_(x_h_actions[:, :, 2:])
x_h_a_feats = torch.cat(x_h_a_feats, dim=-1)
f_h_actions = self.h_id_norm(x_h_id) + self.h_a_feat_norm(x_h_a_feats)
f_h_actions = self.history_action_pe(f_h_actions)
for layer in self.history_action_net:
f_h_actions = layer(f_h_actions, src_key_padding_mask=h_mask)
for layer in self.action_history_net:
f_actions = layer(
f_actions, f_h_actions, tgt_key_padding_mask=mask, memory_key_padding_mask=h_mask)
f_actions = self.action_norm(f_actions)
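The new history path above replaces the old direct cross-attention over raw history features with two stages: a self-attention encoder over the history actions (with a key-padding mask), followed by cross-attention from the current action features into that encoded memory. A standalone sketch of the pattern; shapes and layer counts are illustrative, not the repo's exact config:

```python
import torch
import torch.nn as nn

c, num_heads = 128, 8
enc = nn.TransformerEncoderLayer(
    c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)
dec = nn.TransformerDecoderLayer(
    c, num_heads, c * 4, dropout=0.0, batch_first=True, norm_first=True)

f_actions = torch.randn(2, 24, c)              # current legal actions
f_h_actions = torch.randn(2, 32, c)            # history actions
h_mask = torch.zeros(2, 32, dtype=torch.bool)  # True marks padded slots

# Stage 1: self-attention over the history, ignoring padded slots.
f_h_actions = enc(f_h_actions, src_key_padding_mask=h_mask)
# Stage 2: current actions cross-attend to the encoded history.
f_actions = dec(f_actions, f_h_actions, memory_key_padding_mask=h_mask)
```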
@@ -385,13 +400,12 @@ class Actor(nn.Module):
class PPOAgent(nn.Module):
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
num_history_action_layers=2, embedding_shape=None, bias=False,
def __init__(self, channels=128, num_card_layers=2, num_action_layers=2, embedding_shape=None, bias=False,
affine=True, a_trans=True):
super(PPOAgent, self).__init__()
self.encoder = Encoder(
channels, num_card_layers, num_action_layers, num_history_action_layers, embedding_shape, bias, affine)
channels, num_card_layers, num_action_layers, embedding_shape, bias, affine)
c = channels
self.actor = Actor(c, a_trans)
......
import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
return self.update_stats_and_infos(*super().step(action))
def update_stats_and_infos(self, *args):
observations, rewards, terminated, truncated, infos = args
dones = np.logical_or(terminated, truncated)
self.episode_returns += infos.get("reward", rewards)
self.episode_lengths += 1
self.returned_episode_returns = np.where(
dones, self.episode_returns, self.returned_episode_returns
)
self.returned_episode_lengths = np.where(
dones, self.episode_lengths, self.returned_episode_lengths
)
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
# env_id = infos["env_id"]
# self.env_id = env_id
# self.episode_returns[env_id] += infos["reward"]
# self.returned_episode_returns[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
# )
# self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
# self.episode_lengths[env_id] += 1
# self.returned_episode_lengths[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
# )
# self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
return (
observations,
rewards,
dones,
infos,
)
def async_reset(self):
self.env.async_reset()
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def recv(self):
return self.update_stats_and_infos(*self.env.recv())
def send(self, action):
return self.env.send(action)
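A minimal usage sketch of `RecordEpisodeStatistics`, using a vectorized gymnasium env as a stand-in for envpool (in the repo the wrapper is applied to the YGOPro envpool env). Note the wrapper returns a 4-tuple, and `infos["r"]`/`infos["l"]` only hold finished-episode stats at indices where `dones` is `True`:

```python
import gymnasium as gym
import numpy as np

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 4)
envs = RecordEpisodeStatistics(envs)

obs, infos = envs.reset(seed=0)
for _ in range(200):
    actions = np.random.randint(0, 2, size=envs.num_envs)
    obs, rewards, dones, infos = envs.step(actions)
    for i, done in enumerate(dones):
        if done:  # stats are only valid for just-finished episodes
            print(f"env {i}: return={infos['r'][i]}, length={infos['l'][i]}")
```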
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
\ No newline at end of file
import numpy as np
def evaluate(envs, act_fn, params, rnn_state=None):
num_episodes = envs.num_envs
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
collected = np.zeros((num_episodes,), dtype=np.bool_)
while True:
if rnn_state is None:
actions = act_fn(params, obs)
else:
rnn_state, actions = act_fn(params, (rnn_state, obs))
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if not d or collected[idx]:
continue
collected[idx] = True
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
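A sketch of calling `evaluate` with a trivial random policy. It assumes `envs` has been wrapped (e.g. with `RecordEpisodeStatistics`) so `step` returns `(obs, rewards, dones, info)` with `'l'`/`'r'` populated, and that observations are a dict of batched arrays; the option count 24 mirrors the `max_options` default and is otherwise arbitrary:

```python
import numpy as np

def random_act_fn(params, obs, num_actions=24):
    # obs is assumed to be a dict of batched arrays; use any entry
    # to recover the batch size
    batch = len(next(iter(obs.values())))
    return np.random.randint(0, num_actions, size=batch)

eval_return, eval_ep_len, eval_win_rate = evaluate(envs, random_act_fn, params=None)
print(f"return={eval_return:.3f} len={eval_ep_len:.1f} win_rate={eval_win_rate:.3f}")
```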
import jax.numpy as jnp
from ygoai.rl.env import RecordEpisodeStatistics
def masked_mean(x, valid):
x = jnp.where(valid, x, jnp.zeros_like(x))
return x.sum() / valid.sum()
def masked_normalize(x, valid, epsilon=1e-8):
x = jnp.where(valid, x, jnp.zeros_like(x))
n = valid.sum()
mean = x.sum() / n
variance = jnp.square(x - mean).sum() / n
return (x - mean) / jnp.sqrt(variance + epsilon)
\ No newline at end of file
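A tiny worked example of the masked helpers above: with only the first three entries valid, `masked_mean` averages the three valid values. Note that `masked_normalize` zero-fills invalid entries first, so those zeros still enter the variance term:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 99.0])
valid = jnp.array([True, True, True, False])

print(masked_mean(x, valid))       # (1 + 2 + 3) / 3 = 2.0
print(masked_normalize(x, valid))  # stats computed on [1, 2, 3, 0]
```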
import envpool2
print(envpool2.list_all_envs())
\ No newline at end of file