Commit 4f2ad15b authored by sbl1996@126.com

Add cleanba PPO for TPU

parent 50353ff4
@@ -135,6 +135,8 @@ class Args:
"""the number of iterations (computed in runtime)"""
world_size: int = 0
"""the number of processes (computed in runtime)"""
num_embeddings: Optional[int] = None
"""the number of embeddings (computed in runtime)"""
def make_env(args, num_envs, num_threads, mode='self'):
@@ -148,7 +150,7 @@ def make_env(args, num_envs, num_threads, mode='self'):
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
play_mode=mode,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
......
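With play_mode now taken from the new mode argument, the same factory can also build evaluation environments. A hedged usage sketch: the 'bot' mode string and the eval-specific args fields are placeholders, not names confirmed by this commit.
# hypothetical usage of the new `mode` parameter
envs = make_env(args, args.local_num_envs, args.num_threads)  # training: self-play (default mode)
eval_envs = make_env(args, args.local_eval_episodes, args.num_threads, mode='bot')  # evaluation vs. a scripted opponent (assumed mode name)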
@@ -221,13 +221,8 @@ def actor(
return logits, value
if args.compile:
# It seems that using torch.compile twice causes a segfault at start, so only predict_step is compiled and the agent is used as-is (the torch.jit.trace fallback below stays commented out)
predict_step = torch.compile(predict_step, mode=args.compile)
agent_r = agent
# example_obs = create_obs(envs.observation_space, (args.local_num_envs,), device=device)
# with torch.no_grad():
# agent_r = torch.jit.trace(agent, (example_obs,), check_tolerance=False, check_trace=False)
else:
agent_r = agent
......
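For reference, a standalone sketch of the compile-once inference pattern used above; the toy model and the mode string are illustrative, while in the script args.compile carries the torch.compile mode.
import torch

model = torch.nn.Linear(8, 4)

def predict_step(x):
    # inference-only step, compiled once and reused by the actor loop
    with torch.no_grad():
        return model(x)

predict_step = torch.compile(predict_step, mode="reduce-overhead")
out = predict_step(torch.randn(2, 8))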
import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns = np.where(
dones, self.episode_returns, self.returned_episode_returns
)
self.returned_episode_lengths = np.where(
dones, self.episode_lengths, self.returned_episode_lengths
)
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
\ No newline at end of file
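A standalone check of the wrapper above on a single gymnasium env (num_envs defaults to 1); in the training script it wraps the vectorized ygoenv environments built by make_env instead.
import gymnasium as gym

env = RecordEpisodeStatistics(gym.make("CartPole-v1"))
obs, infos = env.reset(seed=0)
obs, reward, done, infos = env.step(env.action_space.sample())
print(infos["r"], infos["l"])  # running return/length, reported once an episode finishes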
import numpy as np
def evaluate(envs, act_fn, params):
num_episodes = envs.num_envs
episode_lengths = []
episode_rewards = []
eval_win_rates = []
obs = envs.reset()[0]
while True:
actions = act_fn(params, obs)
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
for idx, d in enumerate(dones):
if not d:
continue
episode_length = info['l'][idx]
episode_reward = info['r'][idx]
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
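A hedged usage sketch of evaluate: act_fn is whatever jitted policy the training script provides; apply_fn, params, and eval_envs below are placeholders for those objects, not names from this commit.
import jax
import jax.numpy as jnp

@jax.jit
def act_fn(params, obs):
    logits = apply_fn(params, obs)  # placeholder for the agent's apply function
    return jnp.argmax(logits, axis=-1)

eval_return, eval_ep_len, eval_win_rate = evaluate(eval_envs, act_fn, params)
print(f"return={eval_return:.3f} len={eval_ep_len:.1f} win_rate={eval_win_rate:.3f}")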
import jax.numpy as jnp
from ygoai.rl.env import RecordEpisodeStatistics
def masked_mean(x, valid):
x = jnp.where(valid, x, jnp.zeros_like(x))
return x.sum() / valid.sum()
def masked_normalize(x, valid, epsilon=1e-8):
    x = jnp.where(valid, x, jnp.zeros_like(x))
    n = valid.sum()
    mean = x.sum() / n
    # mask the squared deviations as well, so padded entries do not inflate the variance
    variance = jnp.where(valid, jnp.square(x - mean), 0.0).sum() / n
    return (x - mean) / jnp.sqrt(variance + epsilon)
\ No newline at end of file
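A quick self-contained check of the two helpers on a padded batch:
import jax.numpy as jnp

adv = jnp.array([1.0, 2.0, 3.0, 9.9])
valid = jnp.array([True, True, True, False])  # last slot is padding
print(masked_mean(adv, valid))       # 2.0 -- the padded entry is ignored
print(masked_normalize(adv, valid))  # zero mean / unit variance over the valid entries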
@@ -49,8 +49,11 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
if scaler is None:
loss.backward()
else:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
......
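The optimizer step itself is not part of this hunk; wherever it happens, it has to mirror the optional-scaler pattern above. A hedged sketch of how that usually looks (the clip value and variable names are assumptions):
# sketch: gradient clipping + parameter update with an optional GradScaler
torch.nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
if scaler is None:
    optimizer.step()
else:
    scaler.step(optimizer)  # skips the step if gradients contain inf/NaN
    scaler.update()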
@@ -6,55 +6,7 @@ import pickle
import optree
import torch
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
def reset(self, **kwargs):
observations, infos = self.env.reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
class CompatEnv(gym.Wrapper):
def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = self.env.step(action)
dones = np.logical_or(terminated, truncated)
return (
observations,
rewards,
dones,
infos,
)
from ygoai.rl.env import RecordEpisodeStatistics
def split_param_groups(model, regex):
......
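The body of split_param_groups is elided here; below is a hypothetical stand-in (renamed on purpose) showing what a regex-based split is typically used for: giving parameter groups different optimizer settings. The regex and learning rates are made up.
import re
import torch

def split_params_by_regex(model, regex):
    # hypothetical helper: partition parameters by whether their name matches `regex`
    matched, rest = [], []
    for name, param in model.named_parameters():
        (matched if re.search(regex, name) else rest).append(param)
    return matched, rest

model = torch.nn.Sequential(torch.nn.Embedding(100, 16), torch.nn.Linear(16, 2))
embed_params, other_params = split_params_by_regex(model, r"^0\.")
optimizer = torch.optim.Adam([
    {"params": embed_params, "lr": 1e-4},
    {"params": other_params, "lr": 3e-4},
])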