Commit 330ee6af authored by sbl1996@126.com's avatar sbl1996@126.com

Add truncated LSTM

parent cd59a6e9
...@@ -26,8 +26,10 @@ class Args: ...@@ -26,8 +26,10 @@ class Args:
seed: int = 1 seed: int = 1
"""the random seed""" """the random seed"""
env_id: str = "YGOPro-v1" env_id1: str = "YGOPro-v1"
"""the id of the environment""" """the id of the environment1"""
env_id2: Optional[str] = None
"""the id of the environment2, defaults to `env_id1`"""
deck: str = "../assets/deck" deck: str = "../assets/deck"
"""the deck file to use""" """the deck file to use"""
deck1: Optional[str] = None deck1: Optional[str] = None
...@@ -40,10 +42,16 @@ class Args: ...@@ -40,10 +42,16 @@ class Args:
"""the language to use""" """the language to use"""
max_options: int = 24 max_options: int = 24
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 32 n_history_actions1: int = 32
"""the number of history actions to use""" """the number of history actions to use for the environment1"""
n_history_actions2: Optional[int] = None
"""the number of history actions to use for the environment2, defaults to `n_history_actions1`"""
num_embeddings: Optional[int] = None num_embeddings: Optional[int] = None
"""the number of embeddings of the agent""" """the number of embeddings of the agent"""
accurate: bool = True
"""whether to do accurate evaluation. If not, there will be more short games"""
reverse: bool = False
"""whether to reverse the order of the agents"""
verbose: bool = False verbose: bool = False
"""whether to print debug information""" """whether to print debug information"""
...@@ -101,56 +109,94 @@ if __name__ == "__main__": ...@@ -101,56 +109,94 @@ if __name__ == "__main__":
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs) args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file) if args.env_id2 is None:
args.env_id2 = args.env_id1
if args.n_history_actions2 is None:
args.n_history_actions2 = args.n_history_actions1
args.deck1 = args.deck1 or deck cross_env = args.env_id1 != args.env_id2
args.deck2 = args.deck2 or deck env_id1 = args.env_id1
env_id2 = args.env_id2
seed = args.seed deck1 = init_ygopro(env_id1, args.lang, args.deck, args.code_list_file)
if not cross_env:
deck2 = deck1
else:
deck2 = init_ygopro(env_id2, args.lang, args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck1
args.deck2 = args.deck2 or deck2
seed = args.seed + 100000
random.seed(seed)
seed = random.randint(0, 1e8)
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
if args.xla_device is not None: if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device) os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
if args.accurate:
if args.num_envs != args.num_episodes:
args.num_envs = args.num_episodes
print("Set num_envs to num_episodes for accurate evaluation")
num_envs = args.num_envs num_envs = args.num_envs
envs = ygoenv.make( env_option = dict(
task_id=args.env_id,
env_type="gymnasium", env_type="gymnasium",
num_envs=num_envs, num_envs=num_envs,
num_threads=args.env_threads, num_threads=args.env_threads,
seed=seed, seed=seed,
deck1=args.deck1,
deck2=args.deck2,
player=-1, player=-1,
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self', play_mode='self',
async_reset=False, async_reset=False,
verbose=args.verbose, verbose=args.verbose,
record=args.record, record=args.record,
) )
obs_space = envs.observation_space envs1 = ygoenv.make(
envs.num_envs = num_envs task_id=env_id1,
envs = RecordEpisodeStatistics(envs) n_history_actions=args.n_history_actions1,
deck1=args.deck1,
deck2=args.deck2,
**env_option,
)
if cross_env:
envs2 = ygoenv.make(
task_id=env_id2,
n_history_actions=args.n_history_actions2,
deck1=deck2,
deck2=deck2,
**env_option,
)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
obs_space1 = envs1.observation_space
envs1.num_envs = num_envs
envs1 = RecordEpisodeStatistics(envs1)
sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample())
agent1 = create_agent1(args) agent1 = create_agent1(args)
rstate = agent1.init_rnn_state(1) rstate1 = agent1.init_rnn_state(1)
params1 = jax.jit(agent1.init)(agent_key, sample_obs, rstate) params1 = jax.jit(agent1.init)(key, sample_obs1, rstate1)
with open(args.checkpoint1, "rb") as f: with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params1, f.read()) params1 = flax.serialization.from_bytes(params1, f.read())
if cross_env:
obs_space2 = envs2.observation_space
envs2.num_envs = num_envs
envs2 = RecordEpisodeStatistics(envs2)
sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample())
else:
sample_obs2 = sample_obs1
if args.checkpoint1 == args.checkpoint2: if args.checkpoint1 == args.checkpoint2:
params2 = params1 params2 = params1
else: else:
agent2 = create_agent2(args) agent2 = create_agent2(args)
rstate = agent2.init_rnn_state(1) rstate2 = agent2.init_rnn_state(1)
params2 = jax.jit(agent2.init)(agent_key, sample_obs, rstate) params2 = jax.jit(agent2.init)(key, sample_obs2, rstate2)
with open(args.checkpoint2, "rb") as f: with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params2, f.read()) params2 = flax.serialization.from_bytes(params2, f.read())
...@@ -169,11 +215,11 @@ if __name__ == "__main__": ...@@ -169,11 +215,11 @@ if __name__ == "__main__":
next_rstate = jnp.where(done[:, None], 0, next_rstate) next_rstate = jnp.where(done[:, None], 0, next_rstate)
return next_rstate, probs return next_rstate, probs
if args.num_envs != 1: if num_envs != 1:
@jax.jit @jax.jit
def get_probs2(params1, params2, rstate1, rstate2, obs, main, done): def get_probs2(params1, params2, rstate1, rstate2, obs1, obs2, main, done):
next_rstate1, probs1 = get_probs(params1, rstate1, obs, None, 1) next_rstate1, probs1 = get_probs(params1, rstate1, obs1, None, 1)
next_rstate2, probs2 = get_probs(params2, rstate2, obs, None, 2) next_rstate2, probs2 = get_probs(params2, rstate2, obs2, None, 2)
probs = jnp.where(main[:, None], probs1, probs2) probs = jnp.where(main[:, None], probs1, probs2)
rstate1 = jax.tree.map( rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1) lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
...@@ -183,19 +229,26 @@ if __name__ == "__main__": ...@@ -183,19 +229,26 @@ if __name__ == "__main__":
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2)) lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return rstate1, rstate2, probs return rstate1, rstate2, probs
def predict_fn(rstate1, rstate2, obs, main, done): def predict_fn(rstate1, rstate2, obs1, obs2, main, done):
rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done) rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs1, obs2, main, done)
return rstate1, rstate2, np.array(probs) return rstate1, rstate2, np.array(probs)
else: else:
def predict_fn(rstate1, rstate2, obs, main, done): def predict_fn(rstate1, rstate2, obs1, obs2, main, done):
if main[0]: if main[0]:
rstate1, probs = get_probs(params1, rstate1, obs, done, 1) rstate1, probs = get_probs(params1, rstate1, obs1, done, 1)
else: else:
rstate2, probs = get_probs(params2, rstate2, obs, done, 2) rstate2, probs = get_probs(params2, rstate2, obs2, done, 2)
return rstate1, rstate2, np.array(probs) return rstate1, rstate2, np.array(probs)
obs, infos = envs.reset() obs1, infos1 = envs1.reset()
next_to_play = infos['to_play'] next_to_play1 = infos1['to_play']
if cross_env:
obs2, infos2 = envs2.reset()
next_to_play2 = infos2['to_play']
np.testing.assert_array_equal(next_to_play1, next_to_play2)
else:
obs2 = obs1
dones = np.zeros(num_envs, dtype=np.bool_) dones = np.zeros(num_envs, dtype=np.bool_)
episode_rewards = [] episode_rewards = []
...@@ -209,12 +262,17 @@ if __name__ == "__main__": ...@@ -209,12 +262,17 @@ if __name__ == "__main__":
start = time.time() start = time.time()
start_step = step start_step = step
main_player = np.concatenate([ first_player = np.zeros(num_envs // 2, dtype=np.int64)
np.zeros(num_envs // 2, dtype=np.int64), second_player = np.ones(num_envs - num_envs // 2, dtype=np.int64)
np.ones(num_envs - num_envs // 2, dtype=np.int64) if args.reverse:
]) main_player = np.concatenate([second_player, first_player])
else:
main_player = np.concatenate([first_player, second_player])
# main_player = np.zeros(num_envs, dtype=np.int64)
# main_player = np.ones(num_envs, dtype=np.int64)
rstate1 = agent1.init_rnn_state(num_envs) rstate1 = agent1.init_rnn_state(num_envs)
rstate2 = agent2.init_rnn_state(num_envs) rstate2 = agent2.init_rnn_state(num_envs)
collected = np.zeros((args.num_episodes,), dtype=np.bool_)
if not args.verbose: if not args.verbose:
pbar = tqdm(total=args.num_episodes) pbar = tqdm(total=args.num_episodes)
...@@ -227,35 +285,45 @@ if __name__ == "__main__": ...@@ -227,35 +285,45 @@ if __name__ == "__main__":
model_time = env_time = 0 model_time = env_time = 0
_start = time.time() _start = time.time()
main = next_to_play == main_player main = next_to_play1 == main_player
rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones) rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs1, obs2, main, dones)
actions = probs.argmax(axis=1) actions = probs.argmax(axis=1)
model_time += time.time() - _start model_time += time.time() - _start
to_play = next_to_play to_play1 = next_to_play1
_start = time.time() _start = time.time()
obs, rewards, dones, infos = envs.step(actions) obs1, rewards1, dones1, infos1 = envs1.step(actions)
next_to_play = infos['to_play'] next_to_play1 = infos1['to_play']
if cross_env:
obs2, rewards2, dones2, infos2 = envs2.step(actions)
next_to_play2 = infos2['to_play']
np.testing.assert_array_equal(next_to_play1, next_to_play2)
np.testing.assert_array_equal(dones1, dones2)
else:
obs2 = obs1
env_time += time.time() - _start env_time += time.time() - _start
step += 1 step += 1
for idx, d in enumerate(dones): for idx, d in enumerate(dones1):
if d: if not d or (args.accurate and collected[idx]):
win_reason = infos['win_reason'][idx] continue
pl = 1 if to_play[idx] == main_player[idx] else -1 collected[idx] = True
episode_length = infos['l'][idx] win_reason = infos1['win_reason'][idx]
episode_reward = infos['r'][idx] pl = 1 if main[idx] else -1
episode_length = infos1['l'][idx]
episode_reward = infos1['r'][idx]
main_reward = episode_reward * pl main_reward = episode_reward * pl
win = int(main_reward > 0) win = int(main_reward > 0)
win_player = 0 if (to_play[idx] == 0 and episode_reward > 0) or (to_play[idx] == 1 and episode_reward < 0) else 1 win_player = 0 if (to_play1[idx] == 0 and episode_reward > 0) or (to_play1[idx] == 1 and episode_reward < 0) else 1
win_players.append(win_player) win_players.append(win_player)
win_agent = 1 if main_reward > 0 else 2 win_agent = 1 if main_reward > 0 else 2
win_agents.append(win_agent) win_agents.append(win_agent)
# print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
episode_lengths.append(episode_length) episode_lengths.append(episode_length)
episode_rewards.append(main_reward) episode_rewards.append(main_reward)
win_rates.append(win) win_rates.append(win)
...@@ -269,6 +337,8 @@ if __name__ == "__main__": ...@@ -269,6 +337,8 @@ if __name__ == "__main__":
# Only when num_envs=1, we switch the player here # Only when num_envs=1, we switch the player here
if args.verbose: if args.verbose:
main_player = 1 - main_player main_player = 1 - main_player
else:
main_player[idx] = 1 - main_player[idx]
if len(episode_lengths) >= args.num_episodes: if len(episode_lengths) >= args.num_episodes:
break break
...@@ -277,14 +347,17 @@ if __name__ == "__main__": ...@@ -277,14 +347,17 @@ if __name__ == "__main__":
pbar.close() pbar.close()
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}") print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
episode_lengths = np.array(episode_lengths)
win_players = np.array(win_players) win_players = np.array(win_players)
win_agents = np.array(win_agents) win_agents = np.array(win_agents)
N = len(win_players) N = len(win_players)
N1 = np.sum((win_players == 0) & (win_agents == 1)) mask1 = (win_players == 0) & (win_agents == 1)
N2 = np.sum((win_players == 0) & (win_agents == 2)) mask2 = (win_players == 0) & (win_agents == 2)
N3 = np.sum((win_players == 1) & (win_agents == 1)) mask3 = (win_players == 1) & (win_agents == 1)
N4 = np.sum((win_players == 1) & (win_agents == 2)) mask4 = (win_players == 1) & (win_agents == 2)
N1, N2, N3, N4 = [np.sum(m) for m in [mask1, mask2, mask3, mask4]]
print(f"Payoff matrix:") print(f"Payoff matrix:")
w1 = N1 / N w1 = N1 / N
...@@ -304,6 +377,13 @@ if __name__ == "__main__": ...@@ -304,6 +377,13 @@ if __name__ == "__main__":
print(f"0 {w1:.4f} {w2:.4f}") print(f"0 {w1:.4f} {w2:.4f}")
print(f"1 {w3:.4f} {w4:.4f}") print(f"1 {w3:.4f} {w4:.4f}")
print(f"Length matrix, length of games of agentX as playerY")
l1 = np.mean(episode_lengths[mask1 | mask4])
l2 = np.mean(episode_lengths[mask2 | mask3])
print(f" agent1 agent2")
print(f"0 {l1:3.2f} {l2:3.2f}")
print(f"1 {l2:3.2f} {l1:3.2f}")
total_time = time.time() - start total_time = time.time() - start
total_steps = (step - start_step) * num_envs total_steps = (step - start_step) * num_envs
print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}") print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
......
...@@ -95,6 +95,8 @@ class Args: ...@@ -95,6 +95,8 @@ class Args:
"""the number of actor threads to use""" """the number of actor threads to use"""
num_steps: int = 128 num_steps: int = 128
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
segment_length: Optional[int] = None
"""the length of the segment for training"""
anneal_lr: bool = False anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0 gamma: float = 1.0
...@@ -247,6 +249,53 @@ def init_rnn_state(num_envs, rnn_channels): ...@@ -247,6 +249,53 @@ def init_rnn_state(num_envs, rnn_channels):
) )
def reshape_minibatch(
x, multi_step, num_minibatches, num_steps, segment_length=None, key=None):
# if segment_length is None,
# n_mb = num_minibatches
# if multi_step, from (num_steps, num_envs, ...)) to
# (n_mb, num_steps * (num_envs // n_mb), ...)
# else, from (num_envs, ...) to
# (n_mb, num_envs // n_mb, ...)
# else,
# n_mb_t = num_steps // segment_length
# n_mb_e = num_minibatches // num_minibatches1
# if multi_step, from (num_steps, num_envs, ...)) to
# (n_mb_e, n_mb_t, segment_length * (num_envs // n_mb_e), ...)
# else, from (num_envs, ...) to
# (n_mb_e, num_envs // n_mb_e, ...)
if key is not None:
x = jax.random.permutation(key, x, axis=1 if multi_step else 0)
N = num_minibatches
if segment_length is None:
if multi_step:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
else:
M = segment_length
Nt = num_steps // M
Ne = N // Nt
if multi_step:
x = jnp.reshape(x, (Nt, M, Ne, -1) + x.shape[2:])
x = x.transpose(2, 0, 1, *range(3, x.ndim))
x = jnp.reshape(x, (Ne, Nt, -1) + x.shape[4:])
else:
x = jnp.reshape(x, (Ne, -1) + x.shape[1:])
return x
def reshape_batch(x, num_minibatches, num_steps, segment_length=None):
N = num_minibatches
x = jnp.reshape(x, (N, num_steps, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = jnp.reshape(x, (num_steps, -1) + x.shape[3:])
return x
def rollout( def rollout(
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
args: Args, args: Args,
...@@ -539,6 +588,8 @@ def main(): ...@@ -539,6 +588,8 @@ def main():
args.minibatch_size = args.local_minibatch_size * args.world_size args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size) args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs args.local_env_threads = args.local_env_threads or args.local_num_envs
if args.segment_length is not None:
assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"
if args.embedding_file: if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file) embeddings = load_embeddings(args.embedding_file, args.code_list_file)
...@@ -675,29 +726,17 @@ def main(): ...@@ -675,29 +726,17 @@ def main():
else: else:
eval_params = None eval_params = None
def loss_fn( def advantage_fn(
params, rstate1, rstate2, obs, dones, next_dones, new_logits, new_values, next_dones, switch_or_mains,
switch_or_mains, actions, logits, rewards, mask, next_value): actions, logits, rewards, next_value):
# (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0] num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs num_steps = next_dones.shape[0] // num_envs
def reshape_time_series(x): def reshape_time_series(x):
return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]) return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
mask = mask * (1.0 - dones)
if args.switch:
dones = dones | next_dones
new_logits, new_values = create_agent(args).apply(
params, obs, (rstate1, rstate2), dones, switch_or_mains)[1:3]
new_values = new_values.squeeze(-1)
ratios = distrax.importance_sampling_ratios(distrax.Categorical( ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions) new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratios)
approx_kl = (ratios - 1) - logratio
new_values_, rewards, next_dones, switch_or_mains = jax.tree.map( new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
reshape_time_series, (new_values, rewards, next_dones, switch_or_mains), reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
...@@ -717,6 +756,15 @@ def main(): ...@@ -717,6 +756,15 @@ def main():
target_values, advantages = jax.tree.map( target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages)) lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
return target_values, advantages
def loss_fn(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=None):
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratios)
approx_kl = (ratios - 1) - logratio
if args.norm_adv: if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8) advantages = masked_normalize(advantages, mask, eps=1e-8)
...@@ -743,7 +791,7 @@ def main(): ...@@ -743,7 +791,7 @@ def main():
if args.burn_in_steps: if args.burn_in_steps:
mask = jax.tree.map( mask = jax.tree.map(
lambda x: x.reshape(num_steps, num_envs), mask) lambda x: x.reshape(num_steps, -1), mask)
burn_in_mask = jnp.arange(num_steps) < args.burn_in_steps burn_in_mask = jnp.arange(num_steps) < args.burn_in_steps
mask = jnp.where(burn_in_mask[:, None], 0.0, mask) mask = jnp.where(burn_in_mask[:, None], 0.0, mask)
mask = jnp.reshape(mask, (-1,)) mask = jnp.reshape(mask, (-1,))
...@@ -754,7 +802,57 @@ def main(): ...@@ -754,7 +802,57 @@ def main():
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss) loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl)) return loss, pg_loss, v_loss, ent_loss, approx_kl
def apply_fn(params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
if args.switch:
dones = dones | next_dones
(rstate1, rstate2), new_logits, new_values = create_agent(args).apply(
params, obs, (rstate1, rstate2), dones, switch_or_mains)[:3]
new_values = new_values.squeeze(-1)
return (rstate1, rstate2), new_logits, new_values
def compute_advantage(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value):
new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value)
return target_values, advantages
def compute_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, target_values, advantages, mask):
(rstate1, rstate2), new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=None)
approx_kl, rstate1, rstate2 = jax.tree.map(
jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
return loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
def compute_advantage_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value, mask):
new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value)
loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=dones.shape[0] // next_value.shape[0])
approx_kl = jax.lax.stop_gradient(approx_kl)
return loss, (pg_loss, v_loss, ent_loss, approx_kl)
def single_device_update( def single_device_update(
agent_state: TrainState, agent_state: TrainState,
...@@ -785,7 +883,45 @@ def main(): ...@@ -785,7 +883,45 @@ def main():
switch = T[:, None] == (switch_steps[None, :] - 1) switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage) storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True) if args.segment_length is None:
loss_grad_fn = jax.value_and_grad(compute_advantage_loss, has_aux=True)
else:
loss_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
def compute_advantage_t(next_value):
N = args.num_minibatches // 4
def convert_data1(x: jnp.ndarray, multi_step=True):
return reshape_minibatch(x, multi_step, N, num_steps)
b_init_rstate1, b_init_rstate2, b_next_value = jax.tree.map(
partial(convert_data1, multi_step=False), (init_rstate1, init_rstate2, next_value))
b_storage = jax.tree.map(convert_data1, storage)
if args.switch:
b_switch_or_mains = convert_data1(switch)
else:
b_switch_or_mains = b_storage.mains
target_values, advantages = jax.lax.scan(
lambda x, y: (x, compute_advantage(x, *y)),
agent_state.params,
(
b_init_rstate1,
b_init_rstate2,
b_storage.obs,
b_storage.dones,
b_storage.next_dones,
b_switch_or_mains,
b_storage.actions,
b_storage.logits,
b_storage.rewards,
b_next_value,
))[1]
print(jax.tree.map(lambda x: x.shape, (b_storage.dones, target_values, advantages)))
target_values, advantages = jax.tree.map(
partial(reshape_batch, num_minibatches=N, num_steps=num_steps),
(target_values, advantages))
return target_values, advantages
def update_epoch(carry, _): def update_epoch(carry, _):
agent_state, key = carry agent_state, key = carry
...@@ -798,35 +934,50 @@ def main(): ...@@ -798,35 +934,50 @@ def main():
else: else:
next_value = jnp.where(next_main, next_value, -next_value) next_value = jnp.where(next_main, next_value, -next_value)
def convert_data(x: jnp.ndarray, num_steps): def convert_data(x: jnp.ndarray, multi_step=True):
if args.update_epochs > 1: key = subkey if args.update_epochs > 1 else None
x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0) return reshape_minibatch(
N = args.num_minibatches x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, \ shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
shuffled_next_value = jax.tree.map( partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value)) shuffled_storage = jax.tree.map(convert_data, storage)
shuffled_storage = jax.tree.map(
partial(convert_data, num_steps=num_steps), storage)
if args.switch: if args.switch:
switch_or_mains = convert_data(switch, num_steps) switch_or_mains = convert_data(switch)
else: else:
switch_or_mains = shuffled_storage.mains switch_or_mains = shuffled_storage.mains
shuffled_mask = jnp.ones_like(shuffled_storage.mains) shuffled_mask = ~shuffled_storage.dones
if args.segment_length is None:
shuffled_next_value = convert_data(next_value, multi_step=False)
others = shuffled_storage.rewards, shuffled_next_value, shuffled_mask
def update_minibatch(agent_state, minibatch): def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn( (loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
agent_state.params, *minibatch) agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices") grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads) agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
else:
target_values, advantages = compute_advantage_t(next_value)
shuffled_target_values, shuffled_advantages = jax.tree.map(
convert_data, (target_values, advantages))
others = shuffled_target_values, shuffled_advantages, shuffled_mask
def update_minibatch(agent_state, minibatch):
def update_minibatch_t(carry, minibatch_t):
agent_state, rstate1, rstate2 = carry
minibatch_t = rstate1, rstate2, *minibatch_t
(loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
grads = loss_grad_fn(agent_state.params, *minibatch_t)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
rstate1, rstate2, *minibatch_t = minibatch
(agent_state, _rstate1, _rstate2), \
(loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch_t, (agent_state, rstate1, rstate2), minibatch_t)
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan( agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch, update_minibatch,
...@@ -840,9 +991,7 @@ def main(): ...@@ -840,9 +991,7 @@ def main():
switch_or_mains, switch_or_mains,
shuffled_storage.actions, shuffled_storage.actions,
shuffled_storage.logits, shuffled_storage.logits,
shuffled_storage.rewards, *others,
shuffled_mask,
shuffled_next_value,
), ),
) )
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl) return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
......
...@@ -3,7 +3,7 @@ import time ...@@ -3,7 +3,7 @@ import time
import os import os
import random import random
from typing import Optional, Literal from typing import Optional, Literal
from dataclasses import dataclass from dataclasses import dataclass, field, asdict
import ygoenv import ygoenv
import numpy as np import numpy as np
...@@ -12,6 +12,7 @@ import tyro ...@@ -12,6 +12,7 @@ import tyro
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
@dataclass @dataclass
...@@ -57,14 +58,8 @@ class Args: ...@@ -57,14 +58,8 @@ class Args:
strategy: Literal["random", "greedy"] = "greedy" strategy: Literal["random", "greedy"] = "greedy"
"""the strategy to use if agent is not used""" """the strategy to use if agent is not used"""
num_layers: int = 2 m: ModelArgs = field(default_factory=lambda: ModelArgs())
"""the number of layers for the agent""" """the model arguments for the agent1"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
rnn_type: Optional[str] = "lstm"
"""the type of RNN to use for agent, None for no RNN"""
checkpoint: Optional[str] = None checkpoint: Optional[str] = None
"""the checkpoint to load, must be a `flax_model` file""" """the checkpoint to load, must be a `flax_model` file"""
...@@ -78,11 +73,8 @@ class Args: ...@@ -78,11 +73,8 @@ class Args:
def create_agent(args): def create_agent(args):
return RNNAgent( return RNNAgent(
channels=args.num_channels, **asdict(args.m),
num_layers=args.num_layers,
rnn_channels=args.rnn_channels,
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
rnn_type=args.rnn_type,
) )
...@@ -97,12 +89,14 @@ if __name__ == "__main__": ...@@ -97,12 +89,14 @@ if __name__ == "__main__":
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs) args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file) deck, deck_names = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file, return_deck_names=True)
args.deck1 = args.deck1 or deck args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck args.deck2 = args.deck2 or deck
seed = args.seed seed = args.seed + 100000
random.seed(seed)
seed = random.randint(0, 1e8)
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
...@@ -135,22 +129,21 @@ if __name__ == "__main__": ...@@ -135,22 +129,21 @@ if __name__ == "__main__":
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax import flax
from ygoai.rl.jax.agent import RNNAgent
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax")) cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
agent = create_agent(args) agent = create_agent(args)
key = jax.random.PRNGKey(args.seed) key = jax.random.PRNGKey(seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample()) sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
rstate = agent.init_rnn_state(1) rstate = agent.init_rnn_state(1)
params = jax.jit(agent.init)(agent_key, sample_obs, rstate) params = jax.jit(agent.init)(key, sample_obs, rstate)
with open(args.checkpoint, "rb") as f: with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read()) params = flax.serialization.from_bytes(params, f.read())
params = jax.device_put(params) params = jax.device_put(params)
rstate = agent.init_rnn_state(num_envs)
@jax.jit @jax.jit
def get_probs_and_value(params, rstate, obs, done): def get_probs_and_value(params, rstate, obs, done):
...@@ -180,6 +173,10 @@ if __name__ == "__main__": ...@@ -180,6 +173,10 @@ if __name__ == "__main__":
start = time.time() start = time.time()
start_step = step start_step = step
deck_names = sorted(deck_names)
deck_times = {name: 0 for name in deck_names}
deck_time_count = {name: 0 for name in deck_names}
model_time = env_time = 0 model_time = env_time = 0
while True: while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1): if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
...@@ -211,7 +208,20 @@ if __name__ == "__main__": ...@@ -211,7 +208,20 @@ if __name__ == "__main__":
step += 1 step += 1
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if d: if not d:
continue
for i in range(2):
deck_time = infos['step_time'][idx][i]
deck_name = deck_names[infos['deck'][idx][i]]
time_count = deck_time_count[deck_name]
avg_time = deck_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
deck_times[deck_name] = avg_time
deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % 100 == 0:
print(f"Deck {deck_name}: {avg_time:.4f}")
win_reason = infos['win_reason'][idx] win_reason = infos['win_reason'][idx]
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] episode_reward = infos['r'][idx]
......
import os
import shutil
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, ach_loss
from ygoai.rl.jax.switch import truncated_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
debug: bool = False
"""whether to run the script in debug mode"""
tb_dir: str = "runs"
"""the directory to save the tensorboard logs"""
ckpt_dir: str = "checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket: Optional[str] = None
"""the GCS bucket to save the model checkpoints"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = False
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
num_minibatches: int = 64
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
upgo: bool = True
"""Toggle the use of UPGO for advantages"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = 3.0
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy, typically 0.02"""
logits_threshold: Optional[float] = None
"""the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 1.0
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = False
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 100
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logits: list
rewards: list
mains: list
next_dones: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
switch=True,
multi_step=multi_step,
freeze_id=args.freeze_id,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue,
writer,
learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
local_seed = args.seed + device_thread_id
np.random.seed(local_seed)
envs = make_env(
args,
local_seed,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
local_seed,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
rstate, logits = get_logits(params, inputs)
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
rstate1, rstate2 = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
return rstate1, rstate2, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done)
main = jnp.array(main)
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs))
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
rstate1, rstate2 = jax.tree.map(
lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
storage = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(args.num_steps):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, cached_next_done, cached_main, \
next_rstate1, next_rstate2, action, logits, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=cached_main,
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
partitioned_storage = prepare_data(storage)
storage = []
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_time = time.time() - _start
other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
else:
eval_stats = None
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
eval_stats,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
global_main_devices = [
global_devices[process_index * len(local_devices)]
for process_index in range(args.world_size)
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
tb_log_dir = f"{args.tb_dir}/{run_name}"
if args.local_rank == 0 and not args.debug:
writer = SummaryWriter(tb_log_dir)
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
else:
writer = dummy_writer
def save_fn(obj, path):
with open(path, "wb") as f:
f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint(
args.ckpt_dir, save_fn, n_saved=2)
# seeding
seed_offset = args.local_rank * 10000
args.seed += seed_offset
random.seed(args.seed)
init_key = jax.random.PRNGKey(args.seed - seed_offset)
key = jax.random.PRNGKey(args.seed)
key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(init_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
def loss_fn(
params, rstate1, rstate2, obs, dones, next_dones,
switch, actions, logits, rewards, mask, next_value):
# (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
def reshape_time_series(x):
return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
dones = dones | next_dones
inputs = (rstate1, rstate2, obs, dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
new_values_, rewards, next_dones, switch = jax.tree.map(
reshape_time_series, (new_values, rewards, next_dones, switch),
)
target_values, advantages = truncated_gae_2p0s(
next_value, new_values_, rewards, next_dones, switch,
args.gamma, args.gae_lambda, args.upgo)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
pg_loss = simple_policy_loss(
ratios, logits, new_logits, advantages, args.spo_kld_max)
elif args.logits_threshold is not None:
pg_loss = ach_loss(
actions, logits, new_logits, advantages, args.logits_threshold, args.clip_coef, args.dual_clip_coef)
else:
pg_loss = clipped_surrogate_pg_loss(
ratios, advantages, args.clip_coef, args.dual_clip_coef)
pg_loss = jnp.sum(pg_loss * mask)
v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask)
ent_loss = entropy_loss(new_logits)
ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
ent_loss = ent_loss / n_valids
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
# TODO: rstate will be out-date after the first update, maybe consider R2D2
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main = jnp.concatenate(sharded_next_main)
# reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
next_value = jnp.where(next_main, -next_value, next_value)
def convert_data(x: jnp.ndarray, num_steps):
if args.update_epochs > 1:
x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
N = args.num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
all_reduce_value = jax.pmap(
lambda x: jax.lax.pmean(x, axis_name="main_devices"),
axis_name="main_devices",
devices=global_main_devices,
)
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(7,),
)
params_queues = []
rollout_queues = []
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
actor_thread_id = d_idx * args.num_actor_threads + thread_id
threading.Thread(
target=rollout,
args=(
jax.device_put(actor_keys[actor_thread_id], local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
actor_thread_id,
),
).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
eval_stat_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
eval_stats,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
if eval_stats is not None:
eval_stat_list.append(eval_stats)
if update % args.eval_interval == 0:
eval_stats = np.mean(eval_stat_list, axis=0)
eval_stats = jax.device_put(eval_stats, local_devices[0])
eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
eval_time, eval_return, eval_win_rate = eval_stats
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
f"{global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = args.batch_size * learner_policy_version // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None:
lastest_path = ckpt_maneger.get_latest()
copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
shutil.copyfile(lastest_path, copy_path)
zip_file_path = "latest.zip"
zip_files(zip_file_path, [str(copy_path), tb_log_dir])
sync_to_gcs(args.gcs_bucket, zip_file_path)
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment