Commit 2bf8ce6a authored by sbl1996@126.com

Add jax and lstm

parent 096e743e
@@ -14,18 +14,12 @@ import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
@@ -41,7 +35,7 @@ class Args:
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
@@ -50,8 +44,6 @@
"""the player to play as, -1 means random, 0 is the first player, 1 is the second player"""
play: bool = False
"""whether to play the game"""
selfplay: bool = False
"""whether to use selfplay"""
record: bool = False
"""whether to record the game as YGOPro replays"""
@@ -67,27 +59,36 @@ class Args:
strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used"""
agent: bool = False
"""whether to use the agent"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load"""
checkpoint: Optional[str] = None
"""the checkpoint to load, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = True
"""if toggled, the model will be optimized"""
convert: bool = False
"""if toggled, the model will be converted to a jit model and the program will exit"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to `$OMP_NUM_THREADS` or 4"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, capped at `num_envs`"""
framework: Optional[Literal["torch", "jax"]] = None
"""the framework to use; if None, it is inferred from the checkpoint file name"""
if __name__ == "__main__":
args = tyro.cli(Args)
@@ -102,7 +103,6 @@ if __name__ == "__main__":
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
@@ -113,15 +113,20 @@ if __name__ == "__main__":
random.seed(seed)
np.random.seed(seed)
if args.agent:
if args.checkpoint and args.framework is None:
args.framework = "jax" if "flax_model" in args.checkpoint else "torch"
if args.framework == "torch":
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
else:
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
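If the checkpoint path contains `flax_model`, the JAX branch is taken; otherwise PyTorch is assumed. A minimal sketch of that inference rule, mirroring the auto-detection added above (the helper name and the asserted paths are hypothetical, not part of the commit):

from typing import Optional

def resolve_framework(checkpoint: Optional[str], framework: Optional[str]) -> Optional[str]:
    # A `flax_model` checkpoint selects JAX; anything else (e.g. .pt/.ptj) selects PyTorch.
    if checkpoint and framework is None:
        return "jax" if "flax_model" in checkpoint else "torch"
    return framework

assert resolve_framework("checkpoints/flax_model", None) == "jax"
assert resolve_framework("checkpoints/agent.pt", None) == "torch"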
@@ -136,15 +141,22 @@ if __name__ == "__main__":
player=args.player,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='human' if args.play else ('self' if args.selfplay else ('bot' if args.bot_type == "greedy" else "random")),
play_mode='human' if args.play else ('bot' if args.bot_type == "greedy" else "random"),
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
if args.agent:
if args.checkpoint and args.checkpoint.endswith(".ptj"):
if args.framework == 'torch':
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
if args.checkpoint.endswith(".ptj"):
agent = torch.jit.load(args.checkpoint)
else:
# count lines of code_list
@@ -155,12 +167,11 @@ if __name__ == "__main__":
embedding_shape = len(code_list)
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
if args.checkpoint:
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
state_dict = torch.load(args.checkpoint, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
if args.convert:
@@ -191,6 +202,48 @@ if __name__ == "__main__":
print(f"Optimized model saved to {args.checkpoint}j")
exit(0)
def predict_fn(obs):
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits = agent(obs)[0]
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
return probs
else:
import jax
import jax.numpy as jnp
import flax
from ygoai.rl.jax.agent2 import PPOAgent
def create_agent(args):
return PPOAgent(
channels=128,
num_layers=2,
embedding_shape=args.num_embeddings,
)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(obs):
probs = get_probs(params, obs)
return np.array(probs)
print(f"loaded checkpoint from {args.checkpoint}")
obs, infos = envs.reset()
next_to_play = infos['to_play']
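Both the PyTorch and JAX paths above end by defining the same `predict_fn(obs) -> probs` interface, which is what keeps the play loop below framework-agnostic. A minimal sketch of that contract, assuming `envs` and `predict_fn` as defined above (shapes are illustrative):

probs = predict_fn(obs)            # np.ndarray, roughly (num_envs, max_options), masked action probabilities
actions = probs.argmax(axis=1)     # greedy action per parallel environment
obs, rewards, dones, infos = envs.step(actions)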
@@ -210,16 +263,11 @@ if __name__ == "__main__":
start_step = step
model_time = env_time = 0
if args.agent:
if args.framework:
_start = time.time()
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
with torch.no_grad():
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
probs = probs.cpu().numpy()
probs = predict_fn(obs)
if args.verbose:
print([f"{p:.4f}" for p in probs[probs != 0].tolist()])
print(f"{values[0].item():.4f}")
actions = probs.argmax(axis=1)
model_time += time.time() - _start
else:
@@ -228,13 +276,6 @@ if __name__ == "__main__":
else:
actions = np.zeros(num_envs, dtype=np.int32)
# for k, v in obs.items():
# v = v[0]
# if k == 'cards_':
# v = np.concatenate([np.arange(v.shape[0])[:, None], v], axis=1)
# print(k, v.tolist())
# print(infos)
# print(actions[0])
to_play = next_to_play
_start = time.time()
@@ -249,15 +290,7 @@ if __name__ == "__main__":
win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx]
if args.selfplay:
pl = 1 if to_play[idx] == 0 else -1
winner = 0 if episode_reward * pl > 0 else 1
win = 1 - winner
else:
if episode_reward < 0:
win = 0
else:
win = 1
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
......
@@ -74,9 +74,16 @@ class Args:
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.98
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
fix_target: bool = False
"""if toggled, the opponent is played by a fixed target network that is updated only when the win-rate and return thresholds below are met"""
update_win_rate: float = 0.55
"""the win rate required to update the target network"""
update_return: float = 0.1
"""the average episode return required to update the target network"""
minibatch_size: int = 256
"""the mini-batch size"""
update_epochs: int = 2
@@ -93,8 +100,6 @@ class Args:
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
@@ -169,6 +174,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
@@ -247,6 +253,13 @@ def main():
if args.embedding_file:
agent.freeze_embeddings()
if args.fix_target:
agent_t = Agent(args.num_channels, L, L, embedding_shape).to(device)
agent_t.eval()
agent_t.load_state_dict(agent.state_dict())
else:
agent_t = agent
optim_params = list(agent.parameters())
optimizer = optim.Adam(optim_params, lr=args.learning_rate, eps=1e-5)
@@ -265,10 +278,16 @@ def main():
example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
with torch.no_grad():
traced_model = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
if args.fix_target:
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
else:
traced_model_t = traced_model
train_step = torch.compile(train_step, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
@@ -280,6 +299,7 @@ def main():
learns = torch.zeros((args.collect_length, args.local_num_envs), dtype=torch.bool).to(device)
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
version = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
@@ -296,7 +316,6 @@ def main():
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
next_value1 = next_value2 = 0
step = 0
for iteration in range(args.num_iterations):
@@ -320,6 +339,10 @@ def main():
_start = time.time()
logits, value = predict_step(traced_model, next_obs)
if args.fix_target:
logits_t, value_t = predict_step(traced_model_t, next_obs)
logits = torch.where(learn[:, None], logits, logits_t)
value = torch.where(learn[:, None], value, value_t)
value = value.flatten()
probs = Categorical(logits=logits)
action = probs.sample()
@@ -331,10 +354,6 @@ def main():
action = action.cpu().numpy()
model_time += time.time() - _start
next_nonterminal = 1 - next_done.float()
next_value1 = torch.where(learn, value, next_value1) * next_nonterminal
next_value2 = torch.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play_
next_obs, reward, next_done_, info = envs.step(action)
@@ -378,8 +397,12 @@ def main():
# bootstrap value if not done
with torch.no_grad():
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
nextvalues1 = torch.where(next_to_play == ai_player1, value, -value)
if args.fix_target:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
nextvalues2 = torch.where(next_to_play != ai_player1, value_t, -value_t)
else:
nextvalues2 = -nextvalues1
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
@@ -409,10 +432,10 @@ def main():
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.learn_opponent:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
if args.fix_target:
b_learns = learns[:args.num_steps].reshape(-1)
else:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
# Optimizing the policy and value network
b_inds = np.arange(args.local_batch_size)
@@ -476,6 +499,28 @@ def main():
if rank == 0:
writer.add_scalar("charts/SPS", SPS, global_step)
if args.fix_target:
if rank == 0:
should_update = len(avg_win_rates) == 1000 and np.mean(avg_win_rates) > args.update_win_rate and np.mean(avg_ep_returns) > args.update_return
should_update = torch.tensor(int(should_update), dtype=torch.int64, device=device)
else:
should_update = torch.zeros((), dtype=torch.int64, device=device)
if args.world_size > 1:
dist.all_reduce(should_update, op=dist.ReduceOp.SUM)
should_update = should_update.item() > 0
if should_update:
agent_t.load_state_dict(agent.state_dict())
with torch.no_grad():
traced_model_t = torch.jit.trace(agent_t, (example_obs,), check_tolerance=False, check_trace=False)
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
version += 1
if rank == 0:
torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"agent_v{version}.pt"))
print(f"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}")
avg_win_rates.clear()
avg_ep_returns.clear()
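The gate above is evaluated from rank 0's running statistics and then synchronized with an `all_reduce`, so every process swaps in the new target network during the same iteration. The idiom in isolation, as a sketch with hypothetical names (not part of the commit):

import torch
import torch.distributed as dist

def agree_on_update(decision_on_rank0: bool, rank: int, world_size: int, device) -> bool:
    # Rank 0 decides; the integer flag is summed across ranks so all of them
    # observe the same boolean and reload agent_t together.
    flag = torch.tensor(int(decision_on_rank0) if rank == 0 else 0,
                        dtype=torch.int64, device=device)
    if world_size > 1:
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return flag.item() > 0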
if args.eval_interval and iteration % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
......
@@ -69,11 +69,11 @@ class Args:
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = False
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.98
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
update_win_rate: float = 0.55
@@ -103,8 +103,6 @@ class Args:
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
learn_opponent: bool = False
"""if toggled, the samples from the opponent will be used to train the agent"""
collect_length: Optional[int] = None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
@@ -145,6 +143,8 @@ class Args:
"""the number of iterations (computed at runtime)"""
world_size: int = 0
"""the number of processes (computed at runtime)"""
num_embeddings: Optional[int] = None
"""the number of embeddings (computed at runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
@@ -158,7 +158,7 @@ def make_env(args, num_envs, num_threads, mode='self'):
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
@@ -181,6 +181,7 @@ def main():
args.local_minibatch_size = int(args.minibatch_size // args.world_size)
args.batch_size = int(args.num_envs * args.num_steps)
args.num_iterations = args.total_timesteps // args.batch_size
args.num_minibatches = args.local_batch_size // args.local_minibatch_size
args.env_threads = args.env_threads or args.num_envs
args.torch_threads = args.torch_threads or (int(os.getenv("OMP_NUM_THREADS", "2")) * args.world_size)
args.collect_length = args.collect_length or args.num_steps
@@ -473,7 +474,7 @@ def main():
b_advantages = advantages[:args.num_steps].reshape(-1)
b_values = values[:args.num_steps].reshape(-1)
b_returns = b_advantages + b_values
if args.learn_opponent or selfplay:
if selfplay:
b_learns = torch.ones_like(b_values, dtype=torch.bool)
else:
b_learns = learns[:args.num_steps].reshape(-1)
......
@@ -95,6 +95,69 @@ def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_grad
return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(6, 7))
def compute_gae_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
):
def body_fn(carry, inp):
pred_values, next_values, lastgaelam = carry
next_done, curvalues, reward, switch = inp
nextnonterminal = 1.0 - next_done
next_values = jnp.where(switch, -pred_values, next_values)
lastgaelam = jnp.where(switch, 0, lastgaelam)
delta = reward + gamma * next_values * nextnonterminal - curvalues
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
return (pred_values, curvalues, lastgaelam), lastgaelam
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_value, lastgaelam
_, advantages = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
)
target_values = advantages + values
return advantages, target_values
@partial(jax.jit, static_argnums=(6, 7))
def compute_gae_upgo_2p0s(
next_value, next_done, values, rewards, dones, switch,
gamma, gae_lambda,
):
def body_fn(carry, inp):
pred_value, next_value, next_q, last_return, lastgaelam = carry
next_done, curvalues, reward, switch = inp
gamma_ = gamma * (1.0 - next_done)
next_value = jnp.where(switch, -pred_value, next_value)
next_q = jnp.where(switch, -pred_value, next_q)
last_return = jnp.where(switch, -pred_value, last_return)
lastgaelam = jnp.where(switch, 0, lastgaelam)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
delta = next_q - curvalues
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
carry = pred_value, next_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return)
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_value, next_value, next_value, lastgaelam
_, (advantages, returns) = jax.lax.scan(
body_fn, carry, (dones[1:], values, rewards, switch), reverse=True
)
return returns - values, advantages + values
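If I read the two helpers above correctly, every per-step input is a time-major `(num_steps, num_envs)` array, `switch` marks the steps at which the acting player changes (the GAE accumulator is reset there and the bootstrap value is negated, reflecting the zero-sum setting), and `gamma`/`gae_lambda` are static jit arguments. A minimal call with dummy data, assuming the helpers above are in scope (shapes and values are illustrative):

import jax.numpy as jnp

T, B = 8, 4                                    # num_steps, num_envs
values   = jnp.zeros((T, B))
rewards  = jnp.zeros((T, B))
dones    = jnp.zeros((T, B))
switch   = jnp.zeros((T, B), dtype=jnp.bool_)  # True where the player to move flips
next_value = jnp.zeros((B,))
next_done  = jnp.zeros((B,))

advantages, target_values = compute_gae_2p0s(
    next_value, next_done, values, rewards, dones, switch, 1.0, 0.95)
# compute_gae_upgo_2p0s takes the same arguments and returns the UPGO-style
# advantages together with the GAE target values.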
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
......
@@ -320,3 +320,62 @@ class PPOAgent(nn.Module):
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return logits, value, valid
class PPOLSTMAgent(nn.Module):
channels: int = 128
num_layers: int = 2
lstm_channels: int = 512
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
multi_step: bool = False
@nn.compact
def __call__(self, inputs):
if self.multi_step:
# (num_steps * batch_size, ...)
carry1, carry2, x, done, switch = inputs
batch_size = carry1[0].shape[0]
num_steps = done.shape[0] // batch_size
else:
carry, x = inputs
c = self.channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
f_actions, f_state, mask, valid = encoder(x)
lstm_layer = nn.OptimizedLSTMCell(
self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
if self.multi_step:
def body_fn(cell, carry, x, done, switch):
carry, init_carry = carry
carry, y = cell(carry, x)
carry = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), carry)
carry = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_carry, carry)
return (carry, init_carry), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
f_state, done, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch))
carry, f_state = scan(lstm_layer, (carry1, carry2), f_state, done, switch)
f_state = f_state.reshape((-1, f_state.shape[-1]))
else:
carry, f_state = lstm_layer(carry, f_state)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return carry, logits, value, valid
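A minimal sketch of single-step inference with the new `PPOLSTMAgent` (`multi_step=False`). It assumes `obs_space` is the env observation space, as in the eval script above; `embedding_shape` is left at its default, and the zero `(c, h)` carry matches `lstm_channels=512` (not part of the commit):

import jax
import jax.numpy as jnp

agent = PPOLSTMAgent(channels=128, num_layers=2, lstm_channels=512)
batch_size = 1
carry = (jnp.zeros((batch_size, 512)), jnp.zeros((batch_size, 512)))   # (c, h) of the LSTM cell
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())

params = agent.init(jax.random.PRNGKey(0), (carry, sample_obs))
carry, logits, value, valid = agent.apply(params, (carry, sample_obs))  # carry is threaded across steps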
import numpy as np
def evaluate(envs, act_fn, params):
def evaluate(envs, act_fn, params, rnn_state=None):
num_episodes = envs.num_envs
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
collected = np.zeros((num_episodes,), dtype=np.bool_)
while True:
actions = act_fn(params, obs)
if rnn_state is None:
actions = act_fn(params, obs)
else:
rnn_state, actions = act_fn(params, (rnn_state, obs))
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if not d:
if not d or collected[idx]:
continue
collected[idx] = True
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
......
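The updated `evaluate` accepts either a stateless `act_fn(params, obs)` or, when `rnn_state` is given, a stateful `act_fn(params, (rnn_state, obs))` that returns the new state along with the actions. A sketch of the two conventions (the agent instances, jitting, and greedy argmax are assumptions, not part of the commit):

import jax
import jax.numpy as jnp

@jax.jit
def act_ff(params, obs):
    logits, value, valid = ff_agent.apply(params, obs)      # ff_agent: a PPOAgent instance
    return jnp.argmax(logits, axis=-1)

@jax.jit
def act_rnn(params, inputs):
    rnn_state, obs = inputs
    rnn_state, logits, value, valid = lstm_agent.apply(params, (rnn_state, obs))  # lstm_agent: a PPOLSTMAgent instance
    return rnn_state, jnp.argmax(logits, axis=-1)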
@@ -16,7 +16,7 @@ def entropy_from_logits(logits):
def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
logits, newvalue, valid = agent(mb_obs)[:3]
logits = logits - logits.logsumexp(dim=-1, keepdim=True)
newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
entropy = entropy_from_logits(logits)
......