Commit 43ca871e authored by sbl1996@126.com

Refactor switch

parent 93bc3723
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import random import random
from typing import Optional, Literal from typing import Optional, Literal
from dataclasses import dataclass from dataclasses import dataclass
from tqdm import tqdm
import ygoenv import ygoenv
import numpy as np import numpy as np
...@@ -220,6 +221,9 @@ if __name__ == "__main__": ...@@ -220,6 +221,9 @@ if __name__ == "__main__":
]) ])
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels) rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
if not args.verbose:
pbar = tqdm(total=args.num_episodes)
model_time = env_time = 0 model_time = env_time = 0
while True: while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1): if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
...@@ -255,7 +259,11 @@ if __name__ == "__main__": ...@@ -255,7 +259,11 @@ if __name__ == "__main__":
episode_rewards.append(episode_reward) episode_rewards.append(episode_reward)
win_rates.append(win) win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0) win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n") if args.verbose:
print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
else:
pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
pbar.update(1)
# Only when num_envs=1, we switch the player here # Only when num_envs=1, we switch the player here
if args.verbose: if args.verbose:
...@@ -264,6 +272,8 @@ if __name__ == "__main__": ...@@ -264,6 +272,8 @@ if __name__ == "__main__":
if len(episode_lengths) >= args.num_episodes: if len(episode_lengths) >= args.num_episodes:
break break
if not args.verbose:
pbar.close()
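The non-verbose path follows the usual tqdm lifecycle: create the bar once, update it per finished episode, close it at the end. A minimal standalone version of the same pattern, unrelated to the actual environment loop:

    from tqdm import tqdm
    import numpy as np

    episode_lengths = []
    pbar = tqdm(total=100)
    for _ in range(100):
        # stand-in for one finished episode
        episode_lengths.append(np.random.randint(10, 50))
        pbar.set_postfix(len=np.mean(episode_lengths))
        pbar.update(1)
    pbar.close()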
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}") print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
total_time = time.time() - start total_time = time.time() - start
......
...@@ -16,18 +16,17 @@ import jax ...@@ -16,18 +16,17 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
import optax import optax
import rlax
import distrax import distrax
import tyro import tyro
from flax.training.train_state import TrainState from flax.training.train_state import TrainState
from rich.pretty import pprint from rich.pretty import pprint
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
from ygoai.rl.jax.eval import evaluate from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss from ygoai.rl.jax import vtrace_2p0s, clipped_surrogate_pg_loss, policy_gradient_loss, mse_loss, entropy_loss
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
...@@ -63,10 +62,12 @@ class Args: ...@@ -63,10 +62,12 @@ class Args:
"""the maximum number of options""" """the maximum number of options"""
n_history_actions: int = 32 n_history_actions: int = 32
"""the number of history actions to use""" """the number of history actions to use"""
greedy_reward: bool = True
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 5000000000 total_timesteps: int = 5000000000
"""total timesteps of the experiments""" """total timesteps of the experiments"""
learning_rate: float = 1e-4 learning_rate: float = 1e-3
"""the learning rate of the optimizer""" """the learning rate of the optimizer"""
local_num_envs: int = 128 local_num_envs: int = 128
"""the number of parallel game environments""" """the number of parallel game environments"""
...@@ -74,15 +75,15 @@ class Args: ...@@ -74,15 +75,15 @@ class Args:
"""the number of threads to use for environment""" """the number of threads to use for environment"""
num_actor_threads: int = 2 num_actor_threads: int = 2
"""the number of actor threads to use""" """the number of actor threads to use"""
num_steps: int = 32 num_steps: int = 128
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0 gamma: float = 1.0
"""the discount factor gamma""" """the discount factor gamma"""
num_minibatches: int = 4 upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches""" """the number of mini-batches"""
update_epochs: int = 2 update_epochs: int = 2
"""the K epochs to update the policy""" """the K epochs to update the policy"""
...@@ -94,12 +95,12 @@ class Args: ...@@ -94,12 +95,12 @@ class Args:
"""the minimum value of the importance sampling clipping""" """the minimum value of the importance sampling clipping"""
rho_clip_max: float = 1.007 rho_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping""" """the maximum value of the importance sampling clipping"""
upgo: bool = False
"""whether to use UPGO for policy update"""
ppo_clip: bool = True ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping""" """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25 clip_coef: float = 0.25
"""the PPO surrogate clipping coefficient""" """the PPO surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = None
"""the dual surrogate clipping coefficient"""
ent_coef: float = 0.01 ent_coef: float = 0.01
"""coefficient of the entropy""" """coefficient of the entropy"""
vf_coef: float = 0.5 vf_coef: float = 0.5
...@@ -122,11 +123,13 @@ class Args: ...@@ -122,11 +123,13 @@ class Args:
"""whether to use `jax.distirbuted`""" """whether to use `jax.distirbuted`"""
concurrency: bool = True concurrency: bool = True
"""whether to run the actor and learner concurrently""" """whether to run the actor and learner concurrently"""
bfloat16: bool = True bfloat16: bool = False
"""whether to use bfloat16 for the agent""" """whether to use bfloat16 for the agent"""
thread_affinity: bool = False thread_affinity: bool = False
"""whether to use thread affinity for the environment""" """whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32 local_eval_episodes: int = 32
"""the number of episodes to evaluate the model""" """the number of episodes to evaluate the model"""
eval_interval: int = 50 eval_interval: int = 50
...@@ -145,6 +148,7 @@ class Args: ...@@ -145,6 +148,7 @@ class Args:
actor_devices: Optional[List[str]] = None actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1): def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
...@@ -164,6 +168,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off ...@@ -164,6 +168,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
async_reset=False, async_reset=False,
greedy_reward=args.greedy_reward if mode == 'self' else True,
play_mode=mode, play_mode=mode,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
...@@ -177,7 +182,6 @@ class Transition(NamedTuple): ...@@ -177,7 +182,6 @@ class Transition(NamedTuple):
logits: list logits: list
rewards: list rewards: list
mains: list mains: list
next_dones: list
def create_agent(args, multi_step=False): def create_agent(args, multi_step=False):
...@@ -189,6 +193,7 @@ def create_agent(args, multi_step=False): ...@@ -189,6 +193,7 @@ def create_agent(args, multi_step=False):
param_dtype=jnp.float32, param_dtype=jnp.float32,
lstm_channels=args.rnn_channels, lstm_channels=args.rnn_channels,
multi_step=multi_step, multi_step=multi_step,
freeze_id=args.freeze_id,
) )
...@@ -209,6 +214,10 @@ def rollout( ...@@ -209,6 +214,10 @@ def rollout(
learner_devices, learner_devices,
device_thread_id, device_thread_id,
): ):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env( envs = make_env(
args, args,
args.seed + jax.process_index() + device_thread_id, args.seed + jax.process_index() + device_thread_id,
...@@ -222,7 +231,7 @@ def rollout( ...@@ -222,7 +231,7 @@ def rollout(
args, args,
args.seed + jax.process_index() + device_thread_id, args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes, args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot') args.local_eval_episodes // 4, mode=eval_mode)
eval_envs = RecordEpisodeStatistics(eval_envs) eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids) len_actor_device_ids = len(args.actor_device_ids)
...@@ -249,6 +258,17 @@ def rollout( ...@@ -249,6 +258,17 @@ def rollout(
rstate, logits = get_logits(params, inputs, done) rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1) return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
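get_action_battle runs both parameter sets on the same observation batch and then, per environment, keeps only the acting player's logits and advances only that player's recurrent state. A minimal standalone sketch of that jnp.where selection pattern (dummy shapes, no real agent):

    import jax.numpy as jnp

    num_envs, num_actions, rnn_dim = 4, 3, 8
    main = jnp.array([True, False, True, False])        # which player acts in each env
    logits1 = jnp.ones((num_envs, num_actions)) * 1.0   # player 1's head
    logits2 = jnp.ones((num_envs, num_actions)) * 2.0   # player 2's head
    rstate1 = jnp.zeros((num_envs, rnn_dim))
    next_rstate1 = jnp.ones((num_envs, rnn_dim))

    # take player 1's logits only where it is the acting player
    logits = jnp.where(main[:, None], logits1, logits2)
    # advance player 1's recurrent state only in the envs where it acted
    rstate1 = jnp.where(main[:, None], next_rstate1, rstate1)
    print(logits[:, 0])   # [1. 2. 1. 2.]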
@jax.jit @jax.jit
def sample_action( def sample_action(
params: flax.core.FrozenDict, params: flax.core.FrozenDict,
...@@ -281,7 +301,6 @@ def rollout( ...@@ -281,7 +301,6 @@ def rollout(
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
]) ])
np.random.shuffle(main_player) np.random.shuffle(main_player)
start_step = 0
storage = [] storage = []
@jax.jit @jax.jit
...@@ -312,7 +331,7 @@ def rollout( ...@@ -312,7 +331,7 @@ def rollout(
rollout_time_start = time.time() rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map( init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2)) lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length): for _ in range(args.num_steps):
global_step += args.local_num_envs * n_actors * args.world_size global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs cached_next_obs = next_obs
...@@ -340,7 +359,6 @@ def rollout( ...@@ -340,7 +359,6 @@ def rollout(
actions=action, actions=action,
logits=logits, logits=logits,
rewards=next_reward, rewards=next_reward,
next_dones=next_done,
) )
) )
...@@ -348,15 +366,6 @@ def rollout( ...@@ -348,15 +366,6 @@ def rollout(
if not d: if not d:
continue continue
cur_main = main[idx] cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1) episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward) avg_ep_returns.append(episode_reward)
...@@ -364,10 +373,8 @@ def rollout( ...@@ -364,10 +373,8 @@ def rollout(
rollout_time.append(time.time() - rollout_time_start) rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage) partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:] storage = []
sharded_storage = [] sharded_storage = []
for x in partitioned_storage: for x in partitioned_storage:
if isinstance(x, dict): if isinstance(x, dict):
...@@ -384,7 +391,7 @@ def rollout( ...@@ -384,7 +391,7 @@ def rollout(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2) lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded( sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices), np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main)) (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
learn_opponent = False learn_opponent = False
payload = ( payload = (
global_step, global_step,
...@@ -403,10 +410,13 @@ def rollout( ...@@ -403,10 +410,13 @@ def rollout(
SPS_update = int(args.batch_size / (time.time() - update_time_start)) SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0: if device_thread_id == 0:
print( print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}" f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
) )
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S") time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}") print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step) writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step) writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
...@@ -419,19 +429,28 @@ def rollout( ...@@ -419,19 +429,28 @@ def rollout(
if args.eval_interval and update % args.eval_interval == 0: if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy # Eval with rule-based policy
_start = time.time() _start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0] if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
metric_name = "eval_return"
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
metric_name = "eval_win_rate"
if device_thread_id != 0: if device_thread_id != 0:
eval_queue.put(eval_return) eval_queue.put(eval_stat)
else: else:
eval_stats = [] eval_stats = []
eval_stats.append(eval_return) eval_stats.append(eval_stat)
for _ in range(1, n_actors): for _ in range(1, n_actors):
eval_stats.append(eval_queue.get()) eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats) eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step) writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
if device_thread_id == 0: if device_thread_id == 0:
eval_time = time.time() - _start eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}") print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time other_time += eval_time
...@@ -461,8 +480,15 @@ if __name__ == "__main__": ...@@ -461,8 +480,15 @@ if __name__ == "__main__":
args.minibatch_size = args.local_minibatch_size * args.world_size args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size) args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices() local_devices = jax.local_devices()
global_devices = jax.devices() global_devices = jax.devices()
...@@ -517,6 +543,13 @@ if __name__ == "__main__": ...@@ -517,6 +543,13 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels) rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs)) params = agent.init(agent_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
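If this initialization is read correctly, row 0 of the injected table is an "unknown card" vector (the mean of the pretrained embeddings) and real card ids start from 1. A small illustration of that convention with hypothetical shapes, independent of the actual Encoder/Embed modules:

    import numpy as np

    pretrained = np.random.randn(100, 64).astype(np.float32)  # hypothetical (num_cards, dim) table
    unknown = pretrained.mean(axis=0, keepdims=True)           # id 0 -> mean "unknown" vector
    table = np.concatenate([unknown, pretrained], axis=0)      # shape (num_cards + 1, dim)

    card_ids = np.array([0, 3, 42])   # 0 means "unknown / padding"
    vectors = table[card_ids]         # plain lookup, as an Embed layer would do
    assert vectors.shape == (3, 64)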
tx = optax.MultiSteps( tx = optax.MultiSteps(
optax.chain( optax.chain(
optax.clip_by_global_norm(args.max_grad_norm), optax.clip_by_global_norm(args.max_grad_norm),
...@@ -541,6 +574,13 @@ if __name__ == "__main__": ...@@ -541,6 +574,13 @@ if __name__ == "__main__":
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices) agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs)) # print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit @jax.jit
def get_logits_and_value( def get_logits_and_value(
params: flax.core.FrozenDict, inputs, params: flax.core.FrozenDict, inputs,
...@@ -550,67 +590,54 @@ if __name__ == "__main__": ...@@ -550,67 +590,54 @@ if __name__ == "__main__":
return logits, value.squeeze(-1) return logits, value.squeeze(-1)
def ppo_loss( def ppo_loss(
params, rstate1, rstate2, obs, dones, next_dones, params, rstate1, rstate2, obs, dones, mains,
switch, actions, logits, rewards, mask, next_value): actions, logits, rewards, mask, next_value, next_done):
# (num_steps * local_num_envs // n_mb)) # (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0] num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs num_steps = dones.shape[0] // num_envs
mask = mask & (~dones) mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask) n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch) inputs = (rstate1, rstate2, obs, dones, mains)
new_logits, new_values = get_logits_and_value(params, inputs) new_logits, new_values = get_logits_and_value(params, inputs)
new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map( new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]), lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_logits, new_values, logits, actions, rewards, next_dones, switch, mask), (new_logits, new_values, logits, actions, rewards, dones, mains, mask),
) )
next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)
v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0) ratios = distrax.importance_sampling_ratios(distrax.Categorical(
discounts = (1.0 - next_dones) * args.gamma
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions) new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids # TODO: TD(lambda) for multi-step
target_values, advantages = vtrace_2p0s(
# TODO: use switch to calculate the correct value next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
vtrace_fn = partial( args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max) logratio = jnp.log(ratios)
vtrace_returns = jax.vmap( approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
vtrace_fn, in_axes=1, out_axes=1)(
v_tm1, v_t, rewards, discounts, ratio)
if args.upgo:
advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
rewards, v_t, discounts) - v_tm1
else:
advs = vtrace_returns.q_estimate - v_tm1
if args.ppo_clip: if args.ppo_clip:
pg_loss = jax.vmap( pg_loss = clipped_surrogate_pg_loss(
partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)( ratios, advantages, args.clip_coef, args.dual_clip_coef)
ratio, advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
else: else:
pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
pg_loss = jax.vmap( pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)
rlax.policy_gradient_loss, in_axes=1)( pg_loss = jnp.sum(pg_loss * mask)
new_logits, actions, pg_advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
v_loss = 0.5 * (vtrace_returns.errors ** 2) v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask) v_loss = jnp.sum(v_loss * mask)
entropy_loss = distrax.Softmax(new_logits).entropy() ent_loss = entropy_loss(new_logits)
entropy_loss = jnp.sum(entropy_loss * mask) ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids ent_loss = ent_loss / n_valids
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
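vtrace_2p0s is imported from ygoai.rl.jax and is not shown in this diff. As a reference point, here is a minimal sketch of the standard single-stream V-trace recursion (essentially the vtrace helper removed later in this commit), without the two-player sign handling or the min-side clipping that the 2p0s variant takes as arguments:

    import jax
    import jax.numpy as jnp

    def vtrace_reference(v_tm1, v_t, r_t, discount_t, rho_tm1,
                         rho_clip=1.0, c_clip=1.0):
        # Standard (single-stream) V-trace targets, for reference only.
        rho = jnp.minimum(rho_tm1, rho_clip)   # clipped importance weights
        c = jnp.minimum(rho_tm1, c_clip)       # trace-cutting coefficients
        td = rho * (r_t + discount_t * v_t - v_tm1)

        def body(acc, xs):
            td_t, disc_t, c_t = xs
            acc = td_t + disc_t * c_t * acc
            return acc, acc

        _, errors = jax.lax.scan(
            body, jnp.zeros_like(v_tm1[0]), (td, discount_t, c), reverse=True)
        return errors + v_tm1   # V-trace value targets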
def single_device_update( def single_device_update(
agent_state: TrainState, agent_state: TrainState,
...@@ -618,6 +645,7 @@ if __name__ == "__main__": ...@@ -618,6 +645,7 @@ if __name__ == "__main__":
sharded_init_rstate1: List, sharded_init_rstate1: List,
sharded_init_rstate2: List, sharded_init_rstate2: List,
sharded_next_inputs: List, sharded_next_inputs: List,
sharded_next_done: List,
sharded_next_main: List, sharded_next_main: List,
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
learn_opponent: bool = False, learn_opponent: bool = False,
...@@ -627,20 +655,13 @@ if __name__ == "__main__": ...@@ -627,20 +655,13 @@ if __name__ == "__main__":
jax.tree.map(lambda *x: jnp.concatenate(x), *x) jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2] for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
] ]
next_main, = [ next_main, next_done = [
jnp.concatenate(x) for x in [sharded_next_main] jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]
] ]
# reorder storage of individual players # reorder storage of individual players
# main first, opponent second # main first, opponent second
num_steps, num_envs = storage.rewards.shape num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
...@@ -650,9 +671,7 @@ if __name__ == "__main__": ...@@ -650,9 +671,7 @@ if __name__ == "__main__":
next_value = create_agent(args).apply( next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1) agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct next_value = jnp.where(next_main, next_value, -next_value)
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
def convert_data(x: jnp.ndarray, num_steps): def convert_data(x: jnp.ndarray, num_steps):
if args.update_epochs > 1: if args.update_epochs > 1:
...@@ -666,10 +685,11 @@ if __name__ == "__main__": ...@@ -666,10 +685,11 @@ if __name__ == "__main__":
x = jnp.reshape(x, (N, -1) + x.shape[1:]) x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map( shuffled_init_rstate1, shuffled_init_rstate2, \
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value)) shuffled_next_value, shuffled_next_done = jax.tree.map(
shuffled_storage, shuffled_switch = jax.tree.map( partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
partial(convert_data, num_steps=num_steps), (storage, switch)) shuffled_storage = jax.tree.map(
partial(convert_data, num_steps=num_steps), storage)
shuffled_mask = jnp.ones_like(shuffled_storage.mains) shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch): def update_minibatch(agent_state, minibatch):
...@@ -687,13 +707,13 @@ if __name__ == "__main__": ...@@ -687,13 +707,13 @@ if __name__ == "__main__":
shuffled_init_rstate2, shuffled_init_rstate2,
shuffled_storage.obs, shuffled_storage.obs,
shuffled_storage.dones, shuffled_storage.dones,
shuffled_storage.next_dones, shuffled_storage.mains,
shuffled_switch,
shuffled_storage.actions, shuffled_storage.actions,
shuffled_storage.logits, shuffled_storage.logits,
shuffled_storage.rewards, shuffled_storage.rewards,
shuffled_mask, shuffled_mask,
shuffled_next_value, shuffled_next_value,
shuffled_next_done,
), ),
) )
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
...@@ -712,7 +732,7 @@ if __name__ == "__main__": ...@@ -712,7 +732,7 @@ if __name__ == "__main__":
single_device_update, single_device_update,
axis_name="local_devices", axis_name="local_devices",
devices=global_learner_decices, devices=global_learner_decices,
static_broadcasted_argnums=(7,), static_broadcasted_argnums=(8,),
) )
params_queues = [] params_queues = []
...@@ -727,7 +747,9 @@ if __name__ == "__main__": ...@@ -727,7 +747,9 @@ if __name__ == "__main__":
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1)) params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread( threading.Thread(
target=rollout, target=rollout,
args=( args=(
...@@ -741,6 +763,7 @@ if __name__ == "__main__": ...@@ -741,6 +763,7 @@ if __name__ == "__main__":
d_idx * args.num_actor_threads + thread_id, d_idx * args.num_actor_threads + thread_id,
), ),
).start() ).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10) rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10) data_transfer_time = deque(maxlen=10)
...@@ -790,8 +813,9 @@ if __name__ == "__main__": ...@@ -790,8 +813,9 @@ if __name__ == "__main__":
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step) writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step) writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print( print(
global_step, f"{global_step} actor_update={update}, "
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}", f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
) )
writer.add_scalar( writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
......
...@@ -22,11 +22,12 @@ from flax.training.train_state import TrainState ...@@ -22,11 +22,12 @@ from flax.training.train_state import TrainState
from rich.pretty import pprint from rich.pretty import pprint
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
from ygoai.rl.jax.switch import truncated_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
...@@ -77,8 +78,6 @@ class Args: ...@@ -77,8 +78,6 @@ class Args:
"""the number of actor threads to use""" """the number of actor threads to use"""
num_steps: int = 128 num_steps: int = 128
"""the number of steps to run in each environment per policy rollout""" """the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks""" """Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0 gamma: float = 1.0
...@@ -95,8 +94,10 @@ class Args: ...@@ -95,8 +94,10 @@ class Args:
"""Toggles advantages normalization""" """Toggles advantages normalization"""
clip_coef: float = 0.25 clip_coef: float = 0.25
"""the surrogate clipping coefficient""" """the surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = None
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy""" """the maximum KLD for the SPO policy, typically 0.02"""
ent_coef: float = 0.01 ent_coef: float = 0.01
"""coefficient of the entropy""" """coefficient of the entropy"""
vf_coef: float = 0.5 vf_coef: float = 0.5
...@@ -144,9 +145,10 @@ class Args: ...@@ -144,9 +145,10 @@ class Args:
actor_devices: Optional[List[str]] = None actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1): def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
if not args.thread_affinity: if not args.thread_affinity:
thread_affinity_offset = -1 thread_affinity_offset = -1
if thread_affinity_offset >= 0: if thread_affinity_offset >= 0:
...@@ -163,7 +165,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off ...@@ -163,7 +165,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
max_options=args.max_options, max_options=args.max_options,
n_history_actions=args.n_history_actions, n_history_actions=args.n_history_actions,
async_reset=False, async_reset=False,
greedy_reward=args.greedy_reward, greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode, play_mode=mode,
) )
envs.num_envs = num_envs envs.num_envs = num_envs
...@@ -189,6 +191,7 @@ def create_agent(args, multi_step=False): ...@@ -189,6 +191,7 @@ def create_agent(args, multi_step=False):
param_dtype=jnp.float32, param_dtype=jnp.float32,
lstm_channels=args.rnn_channels, lstm_channels=args.rnn_channels,
multi_step=multi_step, multi_step=multi_step,
freeze_id=args.freeze_id,
) )
...@@ -226,7 +229,7 @@ def rollout( ...@@ -226,7 +229,7 @@ def rollout(
args, args,
args.seed + jax.process_index() + device_thread_id, args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes, args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode) args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs) eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids) len_actor_device_ids = len(args.actor_device_ids)
...@@ -296,7 +299,6 @@ def rollout( ...@@ -296,7 +299,6 @@ def rollout(
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
]) ])
np.random.shuffle(main_player) np.random.shuffle(main_player)
start_step = 0
storage = [] storage = []
@jax.jit @jax.jit
...@@ -327,7 +329,7 @@ def rollout( ...@@ -327,7 +329,7 @@ def rollout(
rollout_time_start = time.time() rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map( init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2)) lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length): for _ in range(args.num_steps):
global_step += args.local_num_envs * n_actors * args.world_size global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs cached_next_obs = next_obs
...@@ -379,10 +381,8 @@ def rollout( ...@@ -379,10 +381,8 @@ def rollout(
rollout_time.append(time.time() - rollout_time_start) rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage) partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:] storage = []
sharded_storage = [] sharded_storage = []
for x in partitioned_storage: for x in partitioned_storage:
if isinstance(x, dict): if isinstance(x, dict):
...@@ -418,10 +418,13 @@ def rollout( ...@@ -418,10 +418,13 @@ def rollout(
SPS_update = int(args.batch_size / (time.time() - update_time_start)) SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0: if device_thread_id == 0:
print( print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}" f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
) )
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S") time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}") print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step) writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step) writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
...@@ -436,14 +439,13 @@ def rollout( ...@@ -436,14 +439,13 @@ def rollout(
_start = time.time() _start = time.time()
if eval_mode == 'bot': if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x) predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate( eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0] eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
metric_name = "eval_return"
else: else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x) predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle( eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2] eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
metric_name = "eval_win_rate" eval_stat = np.array([eval_return, eval_win_rate])
if device_thread_id != 0: if device_thread_id != 0:
eval_queue.put(eval_stat) eval_queue.put(eval_stat)
else: else:
...@@ -451,12 +453,14 @@ def rollout( ...@@ -451,12 +453,14 @@ def rollout(
eval_stats.append(eval_stat) eval_stats.append(eval_stat)
for _ in range(1, n_actors): for _ in range(1, n_actors):
eval_stats.append(eval_queue.get()) eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats) eval_stats = np.stack(eval_stats)
writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step) eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
if device_thread_id == 0: if device_thread_id == 0:
eval_time = time.time() - _start eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time other_time += eval_time
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
if __name__ == "__main__": if __name__ == "__main__":
...@@ -485,8 +489,15 @@ if __name__ == "__main__": ...@@ -485,8 +489,15 @@ if __name__ == "__main__":
args.minibatch_size = args.local_minibatch_size * args.world_size args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size) args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices() local_devices = jax.local_devices()
global_devices = jax.devices() global_devices = jax.devices()
...@@ -541,6 +552,13 @@ if __name__ == "__main__": ...@@ -541,6 +552,13 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels) rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs)) params = agent.init(agent_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
tx = optax.MultiSteps( tx = optax.MultiSteps(
optax.chain( optax.chain(
optax.clip_by_global_norm(args.max_grad_norm), optax.clip_by_global_norm(args.max_grad_norm),
...@@ -587,66 +605,53 @@ if __name__ == "__main__": ...@@ -587,66 +605,53 @@ if __name__ == "__main__":
num_envs = next_value.shape[0] num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs num_steps = dones.shape[0] // num_envs
mask = mask & (~dones) mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask) n_valids = jnp.sum(mask)
real_dones = dones | next_dones real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch) inputs = (rstate1, rstate2, obs, real_dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs) new_logits, new_values = get_logits_and_value(params, inputs)
values, rewards, next_dones, switch = jax.tree.map( new_values_, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs)), lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(jax.lax.stop_gradient(new_values), rewards, next_dones, switch), (new_values, rewards, next_dones, switch),
) )
advantages, target_values = compute_gae_2p0s( ratios = distrax.importance_sampling_ratios(distrax.Categorical(
next_value, values, rewards, next_dones, switch,
args.gamma, args.gae_lambda)
if args.upgo:
advantages = advantages + upgo_advantage(
next_value, values, rewards, next_dones, switch, args.gamma)
advantages, target_values = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions) new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids target_values, advantages = truncated_gae_2p0s(
next_value, new_values_, rewards, next_dones, switch,
args.gamma, args.gae_lambda, args.upgo)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv: if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8) advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss # Policy loss
if args.spo_kld_max is not None: if args.spo_kld_max is not None:
probs = jax.nn.softmax(logits) pg_loss = simple_policy_loss(
new_probs = jax.nn.softmax(new_logits) ratios, logits, new_logits, advantages, args.spo_kld_max)
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else: else:
pg_loss1 = -advantages * ratio pg_loss = clipped_surrogate_pg_loss(
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) ratios, advantages, args.clip_coef, args.dual_clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = jnp.sum(pg_loss * mask) pg_loss = jnp.sum(pg_loss * mask)
v_loss = 0.5 * ((new_values - target_values) ** 2) v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask) v_loss = jnp.sum(v_loss * mask)
entropy_loss = distrax.Softmax(new_logits).entropy() ent_loss = entropy_loss(new_logits)
entropy_loss = jnp.sum(entropy_loss * mask) ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids ent_loss = ent_loss / n_valids
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
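simple_policy_loss is imported from ygoai.rl.jax and not shown here; judging from the inline SPO computation it replaces (the deleted block above), it plausibly looks like the following sketch, which may differ from the real helper:

    import jax
    import jax.numpy as jnp

    def simple_policy_loss_sketch(ratios, logits, new_logits, advantages,
                                  kld_max, eps=1e-8):
        # SPO-style policy loss: scale the policy-gradient term by how much
        # of the (clipped) KL budget the update has already consumed.
        advs = jax.lax.stop_gradient(advantages)
        probs = jax.nn.softmax(logits)
        new_probs = jax.nn.softmax(new_logits)
        kld = jnp.sum(probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
        kld_clip = jnp.clip(kld, 0, kld_max)
        d_ratio = kld_clip / (kld + eps)
        d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)   # no scaling for tiny KL
        sign_a = jnp.sign(advs)
        result = (d_ratio + sign_a - 1) * sign_a
        return -advs * ratios * result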
def single_device_update( def single_device_update(
agent_state: TrainState, agent_state: TrainState,
...@@ -702,7 +707,8 @@ if __name__ == "__main__": ...@@ -702,7 +707,8 @@ if __name__ == "__main__":
x = jnp.reshape(x, (N, -1) + x.shape[1:]) x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map( shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value)) partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map( shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch)) partial(convert_data, num_steps=num_steps), (storage, switch))
...@@ -829,8 +835,9 @@ if __name__ == "__main__": ...@@ -829,8 +835,9 @@ if __name__ == "__main__":
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step) writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step) writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print( print(
global_step, f"{global_step} actor_update={update}, "
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}", f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
) )
writer.add_scalar( writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
......
...@@ -2,340 +2,273 @@ from functools import partial ...@@ -2,340 +2,273 @@ from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from typing import NamedTuple
class VTraceOutput(NamedTuple):
q_estimate: jnp.ndarray
errors: jnp.ndarray
def vtrace(
v_tm1,
v_t,
r_t,
discount_t,
rho_tm1,
lambda_=1.0,
c_clip_min: float = 0.001,
c_clip_max: float = 1.007,
rho_clip_min: float = 0.001,
rho_clip_max: float = 1.007,
stop_target_gradients: bool = True,
):
"""
Args:
v_tm1: values at time t-1.
v_t: values at time t.
r_t: reward at time t.
discount_t: discount at time t.
rho_tm1: importance sampling ratios at time t-1.
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
"""
# Clip importance sampling ratios.
lambda_ = jnp.ones_like(discount_t) * lambda_
c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# Compute the temporal difference errors.
td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# Work backwards computing the td-errors.
def _body(acc, xs):
td_error, discount, c = xs
acc = td_error + discount * c * acc
return acc, acc
_, errors = jax.lax.scan(
_body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# Return errors, maybe disabling gradient flow through bootstrap targets.
errors = jax.lax.select(
stop_target_gradients,
jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
errors)
targets_tm1 = errors + v_tm1
q_bootstrap = jnp.concatenate([
lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
v_t[-1:],
], axis=0)
q_estimate = r_t + discount_t * q_bootstrap
return VTraceOutput(q_estimate=q_estimate, errors=errors)
def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
clipped_objective = jnp.fmin(prob_ratios_t * adv_t, clipped_ratios_t * adv_t)
return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(5, 6))
def compute_gae_2p0s(
next_value, values, rewards, next_dones, switch, gamma, gae_lambda,
):
def body_fn(carry, inp):
boot_value, boot_done, next_value, lastgaelam = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
lastgaelam = jnp.where(switch, 0, lastgaelam)
gamma_ = gamma * (1.0 - next_done)
delta = reward + gamma_ * next_value - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam
next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, lastgaelam
_, advantages = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
target_values = advantages + values
return advantages, target_values
import chex
import distrax
# class VTraceOutput(NamedTuple):
# q_estimate: jnp.ndarray
# errors: jnp.ndarray
@partial(jax.jit, static_argnums=(5,))
def upgo_advantage(
next_value, values, rewards, next_dones, switch, gamma):
def body_fn(carry, inp):
boot_value, boot_done, next_value, next_q, last_return = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
next_q = jnp.where(switch, -boot_value * gamma, next_q)
last_return = jnp.where(switch, -boot_value, last_return)
gamma_ = gamma * (1.0 - next_done)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
carry = boot_value, boot_done, cur_value, next_q, last_return
return carry, last_return
next_done = next_dones[-1]
carry = next_value, next_done, next_value, next_value, next_value
_, returns = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
return returns - values
# def compute_gae_once(carry, inp, gamma, gae_lambda):
# v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2 = carry
# rho, cur_values, log_ratio, next_done, r_t, corr_r_t, main = inp
# v = jnp.where(main, v1, v2)
# next_values = jnp.where(main, next_values1, next_values2)
# reward = jnp.where(main, reward1, reward2)
# xi = jnp.where(main, xi1, xi2)
# p_t = c_t = jnp.minimum(1.0, rho * xi)
# sig_v = p_t * (r_t + reward * rho + next_values - cur_values)
# reg_r = jnp.log(p / p_reg)
# q = r_t + rho * (reward + v)
# q = -eta * + cur_values
# v = cur_values + sig_v + c_t * (v - next_values)
# v1 = jnp.where(main, v, v1)
# v2 = jnp.where(main, v2, v)
# next_values1 = jnp.where(main, cur_values, next_values1)
# next_values2 = jnp.where(main, next_values2, cur_values)
# reward1 = jnp.where(main, 0, r_t + rho * reward1)
# reward2 = jnp.where(main, r_t + rho * reward2, 0)
# xi1 = jnp.where(main, 1, rho * xi1)
# xi2 = jnp.where(main, rho * xi2, 1)
# learn1 = learn
# learn2 = ~learn
# factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
# reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
# reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
# real_done1 = next_done | ~done_used1
# nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
# lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
# real_done2 = next_done | ~done_used2
# nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
# lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
# done_used1 = jnp.where(
# next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
# done_used2 = jnp.where(
# next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
# delta1 = reward1 + gamma * nextvalues1 - curvalues
# delta2 = reward2 + gamma * nextvalues2 - curvalues
# lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
# lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
# advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
# nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
# nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
# lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
# lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
# carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# return carry, advantages
# @partial(jax.jit, static_argnums=(6, 7))
# def vtrace_rnad(
# next_value, next_done, values, rewards, dones, learns,
# gamma, gae_lambda,
# ):
# next_value1 = next_value
# next_value2 = -next_value1
# done_used1 = jnp.ones_like(next_done)
# done_used2 = jnp.ones_like(next_done)
# reward1 = jnp.zeros_like(next_value)
# reward2 = jnp.zeros_like(next_value)
# lastgaelam1 = jnp.zeros_like(next_value)
# lastgaelam2 = jnp.zeros_like(next_value)
# carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
# _, advantages = jax.lax.scan(
# partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
# carry, (dones[1:], values, rewards, learns), reverse=True
# )
# target_values = advantages + values
# return advantages, target_values
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - curvalues
delta2 = reward2 + gamma * nextvalues2 - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, advantages
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
# def vtrace(
# v_tm1,
# v_t,
# r_t,
# discount_t,
# rho_tm1,
# lambda_=1.0,
# c_clip_min: float = 0.001,
# c_clip_max: float = 1.007,
# rho_clip_min: float = 0.001,
# rho_clip_max: float = 1.007,
# stop_target_gradients: bool = True,
# ):
# """
# Args:
# v_tm1: values at time t-1.
# v_t: values at time t.
# r_t: reward at time t.
# discount_t: discount at time t.
# rho_tm1: importance sampling ratios at time t-1.
# lambda_: mixing parameter; a scalar or a vector for timesteps t.
# clip_rho_threshold: clip threshold for importance weights.
# stop_target_gradients: whether or not to apply stop gradient to targets.
# """
# # Clip importance sampling ratios.
# lambda_ = jnp.ones_like(discount_t) * lambda_
# c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
# clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# # Compute the temporal difference errors.
# td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# # Work backwards computing the td-errors.
# def _body(acc, xs):
# td_error, discount, c = xs
# acc = td_error + discount * c * acc
# return acc, acc
# _, errors = jax.lax.scan(
# _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# # Return errors, maybe disabling gradient flow through bootstrap targets.
# errors = jax.lax.select(
# stop_target_gradients,
# jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
# errors)
# targets_tm1 = errors + v_tm1
# q_bootstrap = jnp.concatenate([
# lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
# v_t[-1:],
# ], axis=0)
# q_estimate = r_t + discount_t * q_bootstrap
# return VTraceOutput(q_estimate=q_estimate, errors=errors)
def entropy_loss(logits):
    return distrax.Softmax(logits=logits).entropy()


def mse_loss(y_true, y_pred):
    return 0.5 * ((y_true - y_pred) ** 2)


def policy_gradient_loss(logits, actions, advantages):
    chex.assert_type([logits, actions, advantages], [float, int, float])
    advs = jax.lax.stop_gradient(advantages)
    log_probs = distrax.Softmax(logits=logits).log_prob(actions)
    pg_loss = -log_probs * advs
    return pg_loss


def clipped_surrogate_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None):
    # dual clip from JueWu (Mastering Complex Control in MOBA Games with Deep Reinforcement Learning)
    advs = jax.lax.stop_gradient(advantages)
    clipped_ratios = jnp.clip(ratios, 1 - clip_coef, 1 + clip_coef)
    clipped_obj = jnp.fmin(ratios * advs, clipped_ratios * advs)
    if dual_clip_coef is not None:
        clipped_obj = jnp.where(
            advs >= 0, clipped_obj,
            jnp.fmax(clipped_obj, dual_clip_coef * advs)
        )
    pg_loss = -clipped_obj
    return pg_loss
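# Illustrative sketch (not part of the diff): a toy check of the dual clip above, assuming
# clipped_surrogate_pg_loss is in scope. With a negative advantage and a very large ratio,
# the loss is bounded by dual_clip_coef * |advantage| instead of growing with the ratio.
# The numbers below are made up for demonstration.
import jax.numpy as jnp

ratios = jnp.array([0.5, 1.5, 5.0])
advantages = jnp.array([-1.0, -1.0, -1.0])
loss = clipped_surrogate_pg_loss(ratios, advantages, clip_coef=0.2, dual_clip_coef=3.0)
print(loss)  # ~[0.8, 1.5, 3.0] -- the last entry is capped at 3.0 rather than 5.0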
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
    v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
        last_return1, last_return2, next_q1, next_q2 = carry
    ratio, cur_values, next_done, r_t, main = inp

    v1 = jnp.where(next_done, 0, v1)
    v2 = jnp.where(next_done, 0, v2)
    next_values1 = jnp.where(next_done, 0, next_values1)
    next_values2 = jnp.where(next_done, 0, next_values2)
    reward1 = jnp.where(next_done, 0, reward1)
    reward2 = jnp.where(next_done, 0, reward2)
    xi1 = jnp.where(next_done, 1, xi1)
    xi2 = jnp.where(next_done, 1, xi2)

    discount = gamma * (1.0 - next_done)
    v = jnp.where(main, v1, v2)
    next_values = jnp.where(main, next_values1, next_values2)
    reward = jnp.where(main, reward1, reward2)
    xi = jnp.where(main, xi1, xi2)

    q_t = r_t + ratio * reward + discount * v
    rho_t = jnp.clip(ratio * xi, rho_min, rho_max)
    c_t = jnp.clip(ratio * xi, c_min, c_max)
    sig_v = rho_t * (r_t + ratio * reward + discount * next_values - cur_values)
    v = cur_values + sig_v + c_t * discount * (v - next_values)

    # UPGO advantage (not corrected by importance sampling, unlike V-trace)
    return_t = jnp.where(main, last_return1, last_return2)
    next_q = jnp.where(main, next_q1, next_q2)
    factor = jnp.where(main, jnp.ones_like(r_t), -jnp.ones_like(r_t))
    return_t = r_t + discount * jnp.where(
        next_q >= next_values, return_t, next_values)
    last_return1 = jnp.where(
        next_done, r_t * factor, jnp.where(main, return_t, last_return1))
    last_return2 = jnp.where(
        next_done, r_t * -factor, jnp.where(main, last_return2, return_t))
    next_q = r_t + discount * next_values
    next_q1 = jnp.where(
        next_done, r_t * factor, jnp.where(main, next_q, next_q1))
    next_q2 = jnp.where(
        next_done, r_t * -factor, jnp.where(main, next_q2, next_q))

    v1 = jnp.where(main, v, v1)
    v2 = jnp.where(main, v2, v)
    next_values1 = jnp.where(main, cur_values, next_values1)
    next_values2 = jnp.where(main, next_values2, cur_values)
    reward1 = jnp.where(main, 0, -r_t + ratio * reward1)
    reward2 = jnp.where(main, -r_t + ratio * reward2, 0)
    xi1 = jnp.where(main, 1, ratio * xi1)
    xi2 = jnp.where(main, ratio * xi2, 1)

    carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
        last_return1, last_return2, next_q1, next_q2
    return carry, (v, q_t, return_t)
# ---- removed in this commit ----
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
    next_value, next_done, next_learn,
    values, rewards, dones, learns,
    gamma, gae_lambda,
):
    next_value1 = jnp.where(next_learn, next_value, -next_value)
    next_value2 = -next_value1
    done_used1 = jnp.ones_like(next_done)
    done_used2 = jnp.ones_like(next_done)
    reward1 = jnp.zeros_like(next_value)
    reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = jnp.zeros_like(next_value)
    lastgaelam2 = jnp.zeros_like(next_value)
    carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    _, advantages = jax.lax.scan(
        partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
        carry, (dones[1:], values, rewards, learns), reverse=True
    )
    target_values = advantages + values
    return advantages, target_values


# ---- added in this commit ----
def vtrace_2p0s(
    next_value, ratios, values, rewards, next_dones, mains,
    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
):
    next_value1 = next_value
    next_value2 = -next_value1
    v1 = return1 = next_q1 = next_value1
    v2 = return2 = next_q2 = next_value2
    reward1 = reward2 = jnp.zeros_like(next_value)
    xi1 = xi2 = jnp.ones_like(next_value)
    carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
        return1, return2, next_q1, next_q2

    _, (targets, q_estimate, return_t) = jax.lax.scan(
        partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
        carry, (ratios, values, next_dones, rewards, mains), reverse=True
    )
    advantages = q_estimate - values
    if upgo:
        advantages += return_t - values
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages


# ---- removed in this commit ----
def compute_gae_once_upgo(carry, inp, gamma, gae_lambda):
    next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
    next_done, curvalues, reward, learn = inp
    learn1 = learn
    learn2 = ~learn
    factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
    real_done1 = next_done | ~done_used1
    next_value1 = jnp.where(real_done1, 0, next_value1)
    last_return1 = jnp.where(real_done1, 0, last_return1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    next_value2 = jnp.where(real_done2, 0, next_value2)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    done_used1 = jnp.where(
        next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
    last_return1_ = reward1 + gamma * jnp.where(
        next_q1 >= next_value1, last_return1, next_value1)
    last_return2_ = reward2 + gamma * jnp.where(
        next_q2 >= next_value2, last_return2, next_value2)
    next_q1_ = reward1 + gamma * next_value1
    next_q2_ = reward2 + gamma * next_value2
    delta1 = next_q1_ - curvalues
    delta2 = next_q2_ - curvalues
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    returns = jnp.where(learn1, last_return1_, last_return2_)
    advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
    next_value1 = jnp.where(learn1, curvalues, next_value1)
    next_value2 = jnp.where(learn2, curvalues, next_value2)
    lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
    next_q1 = jnp.where(learn1, next_q1_, next_q1)
    next_q2 = jnp.where(learn2, next_q2_, next_q1)
    last_return1 = jnp.where(learn1, last_return1_, last_return1)
    last_return2 = jnp.where(learn2, last_return2_, last_return2)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    return carry, (advantages, returns)


@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
    next_value, next_done, next_learn,
    values, rewards, dones, learns,
    gamma, gae_lambda,
):
    next_value1 = jnp.where(next_learn, next_value, -next_value)
    next_value2 = -next_value1
    last_return1 = next_q1 = next_value1
    last_return2 = next_q2 = next_value2
    done_used1 = jnp.ones_like(next_done)
    done_used2 = jnp.ones_like(next_done)
    reward1 = jnp.zeros_like(next_value)
    reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = jnp.zeros_like(next_value)
    lastgaelam2 = jnp.zeros_like(next_value)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    _, (advantages, returns) = jax.lax.scan(
        partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda),
        carry, (dones[1:], values, rewards, learns), reverse=True
    )
    return returns - values, advantages + values


# ---- added in this commit ----
def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
    lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
    cur_value, next_done, reward, main = inp
    main1 = main
    main2 = ~main
    factor = jnp.where(main1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(main1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(main2 & done_used2, 0, reward2))
    real_done1 = next_done | ~done_used1
    next_value1 = jnp.where(real_done1, 0, next_value1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    next_value2 = jnp.where(real_done2, 0, next_value2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    done_used1 = jnp.where(
        next_done, main1, jnp.where(main1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, main2, jnp.where(main2 & ~done_used2, True, done_used2))

    # UPGO advantage
    last_return1 = jnp.where(real_done1, 0, last_return1)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    last_return1_ = reward1 + gamma * jnp.where(
        next_q1 >= next_value1, last_return1, next_value1)
    last_return2_ = reward2 + gamma * jnp.where(
        next_q2 >= next_value2, last_return2, next_value2)
    next_q1_ = reward1 + gamma * next_value1
    next_q2_ = reward2 + gamma * next_value2
    next_q1 = jnp.where(main1, next_q1_, next_q1)
    # Keep the opponent's Q estimate unchanged on the main player's steps.
    # (The diff reads `next_q1` in the else branch here, which looks like a copy-paste slip.)
    next_q2 = jnp.where(main2, next_q2_, next_q2)
    last_return1 = jnp.where(main1, last_return1_, last_return1)
    last_return2 = jnp.where(main2, last_return2_, last_return2)
    returns = jnp.where(main1, last_return1_, last_return2_)

    delta1 = next_q1_ - cur_value
    delta2 = next_q2_ - cur_value
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    advantages = jnp.where(main1, lastgaelam1_, lastgaelam2_)
    next_value1 = jnp.where(main1, cur_value, next_value1)
    next_value2 = jnp.where(main2, cur_value, next_value2)
    lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)

    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
    return carry, (advantages, returns)


def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
):
    next_value1 = next_value
    next_value2 = -next_value1
    last_return1 = next_q1 = next_value1
    last_return2 = next_q2 = next_value2
    done_used1 = jnp.ones_like(next_dones[-1])
    done_used2 = jnp.ones_like(next_dones[-1])
    reward1 = reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2

    _, (advantages, returns) = jax.lax.scan(
        partial(truncated_gae_upgo_loop, gamma=gamma, gae_lambda=gae_lambda),
        carry, (values, next_dones, rewards, mains), reverse=True
    )
    if upgo:
        advantages += returns - values
    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
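# Illustrative shape-check sketch (not part of the diff): both new estimators take time-major
# [T, B] rollouts, with `mains` marking the learning player's steps. The values below are
# random placeholders, assuming vtrace_2p0s and truncated_gae_2p0s above are in scope.
import jax
import jax.numpy as jnp

T, B = 8, 4
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
values = jax.random.normal(k1, (T, B))
rewards = jnp.zeros((T, B)).at[-1].set(1.0)                       # terminal win/loss reward
ratios = jnp.exp(0.1 * jax.random.normal(k2, (T, B)))             # new/old policy ratios
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1].set(True)
mains = jax.random.bernoulli(k3, 0.5, (T, B))
next_value = jnp.zeros((B,))

targets, advantages = vtrace_2p0s(
    next_value, ratios, values, rewards, next_dones, mains, gamma=1.0)
assert targets.shape == advantages.shape == (T, B)

targets_gae, adv_gae = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, mains,
    gamma=1.0, gae_lambda=0.95, upgo=True)
assert targets_gae.shape == adv_gae.shape == (T, B)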
def simple_policy_loss(ratios, logits, new_logits, advantages, kld_max, eps=1e-12):
    advs = jax.lax.stop_gradient(advantages)
    probs = jax.nn.softmax(logits)
    new_probs = jax.nn.softmax(new_logits)
    kld = jnp.sum(
        probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
    kld_clip = jnp.clip(kld, 0, kld_max)
    d_ratio = kld_clip / (kld + eps)
    # e == 1 and t == 1
    d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
    sign_a = jnp.sign(advs)
    result = (d_ratio + sign_a - 1) * sign_a
    pg_loss = -advs * ratios * result
    return pg_loss
\ No newline at end of file
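# Illustrative sketch (not part of the diff): a toy call of simple_policy_loss, assuming the
# function defined above is in scope. Logits, actions and advantages are made-up values;
# kld_max gates how far the updated policy may drift before the loss is scaled down.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
old_logits = jax.random.normal(key, (2, 3))        # 2 decisions over 3 legal actions
new_logits = old_logits.at[:, 0].add(0.3)          # slightly shifted policy
actions = jnp.array([0, 2])
old_logp = jax.nn.log_softmax(old_logits)[jnp.arange(2), actions]
new_logp = jax.nn.log_softmax(new_logits)[jnp.arange(2), actions]
ratios = jnp.exp(new_logp - old_logp)
advantages = jnp.array([1.0, -0.5])

pg_loss = simple_policy_loss(ratios, old_logits, new_logits, advantages, kld_max=0.05)
print(pg_loss.shape)  # (2,) -- per-step losses; the caller averages them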
...@@ -150,6 +150,7 @@ class Encoder(nn.Module):
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    freeze_id: bool = False

    @nn.compact
    def __call__(self, x):
...@@ -168,6 +169,8 @@ class Encoder(nn.Module):
        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

        id_embed = embed(n_embed, embed_dim)
        if self.freeze_id:
            # Treat the id embedding output as a constant so a pretrained table is not updated.
            # Binding the module first keeps the lambda from capturing and calling itself.
            id_embed_module = id_embed
            id_embed = lambda x: jax.lax.stop_gradient(id_embed_module(x))
        action_encoder = ActionEncoder(
            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
...@@ -337,6 +340,7 @@ class PPOLSTMAgent(nn.Module):
    param_dtype: jnp.dtype = jnp.float32
    multi_step: bool = False
    switch: bool = True
    freeze_id: bool = False

    @nn.compact
    def __call__(self, inputs):
...@@ -355,6 +359,7 @@ class PPOLSTMAgent(nn.Module):
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            freeze_id=self.freeze_id,
        )
        f_actions, f_state, mask, valid = encoder(x)
...
import jax
import jax.numpy as jnp


def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo
):
    def body_fn(carry, inp):
        boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry
        next_done, cur_value, reward, switch = inp

        next_done = jnp.where(switch, boot_done, next_done)
        next_value = jnp.where(switch, -boot_value, next_value)
        lastgaelam = jnp.where(switch, 0, lastgaelam)
        next_q = jnp.where(switch, -boot_value * gamma, next_q)
        last_return = jnp.where(switch, -boot_value, last_return)

        discount = gamma * (1.0 - next_done)
        last_return = reward + discount * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + discount * next_value
        delta = next_q - cur_value
        lastgaelam = delta + gae_lambda * discount * lastgaelam

        carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return
        return carry, (lastgaelam, last_return)

    next_done = next_dones[-1]
    lastgaelam = jnp.zeros_like(next_value)
    next_q = last_return = next_value
    carry = next_value, next_done, next_value, lastgaelam, next_q, last_return

    _, (advantages, returns) = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards, switch), reverse=True
    )
    if upgo:
        advantages += returns - values
    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
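# Illustrative sketch (not part of the diff): calling the switch-based variant defined just
# above. The `switch` flag appears to mark steps where the acting player changes, flipping the
# sign of the bootstrap value; shapes and values below are placeholders.
import jax
import jax.numpy as jnp

T, B = 8, 4
key = jax.random.PRNGKey(0)
values = jax.random.normal(key, (T, B))
rewards = jnp.zeros((T, B)).at[-1].set(1.0)
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1].set(True)
switch = jnp.zeros((T, B), dtype=jnp.bool_).at[::2].set(True)   # alternate turns every step
next_value = jnp.zeros((B,))

targets, advantages = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch,
    gamma=1.0, gae_lambda=0.95, upgo=False)
assert targets.shape == advantages.shape == (T, B)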
...@@ -58,13 +58,3 @@ def masked_normalize(x, valid, eps=1e-8):
def to_tensor(x, device, dtype=None):
    return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)


# ---- removed in this commit (now provided by ygoai/utils.py) ----
def load_embeddings(embedding_file, code_list_file):
    with open(embedding_file, "rb") as f:
        embeddings = pickle.load(f)
    with open(code_list_file, "r") as f:
        code_list = f.readlines()
    code_list = [int(code.strip()) for code in code_list]
    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
    return embeddings
import pickle
import numpy as np
from pathlib import Path
...@@ -43,4 +45,15 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
    elif 'EDOPro' in env_id:
        from ygoenv.edopro import init_module
        init_module(str(db_path), code_list_file, decks)
    return deck_name
\ No newline at end of file


def load_embeddings(embedding_file, code_list_file):
    with open(embedding_file, "rb") as f:
        embeddings = pickle.load(f)
    with open(code_list_file, "r") as f:
        code_list = f.readlines()
    code_list = [int(code.strip()) for code in code_list]
    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
    return embeddings
\ No newline at end of file
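# Illustrative sketch (not part of the diff): loading pretrained card embeddings and freezing
# the id embedding via the new freeze_id flag. The file paths are placeholders, the remaining
# agent fields are assumed to keep their defaults, and how the embedding matrix is copied into
# the agent's parameters is left out here.
from ygoai.utils import load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent

embeddings = load_embeddings("embeddings.pkl", "code_list.txt")   # assumed paths
n_embed, embed_dim = embeddings.shape

# freeze_id=True wraps the id embedding output in stop_gradient, so a pretrained
# embedding table is not updated by the RL objective.
agent = PPOLSTMAgent(embedding_shape=(n_embed, embed_dim), freeze_id=True)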