Commit 43ca871e authored by sbl1996@126.com

Refactor switch

parent 93bc3723
......@@ -4,6 +4,7 @@ import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
from tqdm import tqdm
import ygoenv
import numpy as np
......@@ -220,6 +221,9 @@ if __name__ == "__main__":
])
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
if not args.verbose:
pbar = tqdm(total=args.num_episodes)
model_time = env_time = 0
while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
......@@ -255,7 +259,11 @@ if __name__ == "__main__":
episode_rewards.append(episode_reward)
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
if args.verbose:
print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
else:
pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
pbar.update(1)
# Only when num_envs=1, we switch the player here
if args.verbose:
......@@ -264,6 +272,8 @@ if __name__ == "__main__":
if len(episode_lengths) >= args.num_episodes:
break
if not args.verbose:
pbar.close()
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
total_time = time.time() - start
......
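The change above swaps per-episode stderr logging for a tqdm progress bar when args.verbose is off. A minimal, self-contained sketch of the same pattern (the episode loop and metrics below are placeholders, not the real evaluator):

import numpy as np
from tqdm import tqdm

def run_eval(num_episodes, verbose=False):
    lengths, rewards, wins = [], [], []
    pbar = None if verbose else tqdm(total=num_episodes)
    for _ in range(num_episodes):
        # placeholder episode result; the real loop steps the env instead
        length, reward, win = 40, 1.0, 1
        lengths.append(length); rewards.append(reward); wins.append(win)
        if verbose:
            print(f"Episode {len(lengths)}: length={length}, reward={reward}, win={win}")
        else:
            # running means shown next to the bar, one tick per finished episode
            pbar.set_postfix(len=np.mean(lengths), reward=np.mean(rewards),
                             win_rate=np.mean(wins))
            pbar.update(1)
    if pbar is not None:
        pbar.close()
    print(f"len={np.mean(lengths)}, reward={np.mean(rewards)}, win_rate={np.mean(wins)}")

run_eval(8)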
......@@ -16,18 +16,17 @@ import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import vtrace_2p0s, clipped_surrogate_pg_loss, policy_gradient_loss, mse_loss, entropy_loss
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
......@@ -63,10 +62,12 @@ class Args:
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = True
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-4
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
......@@ -74,15 +75,15 @@ class Args:
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 32
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
num_minibatches: int = 4
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
......@@ -94,12 +95,12 @@ class Args:
"""the minimum value of the importance sampling clipping"""
rho_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping"""
upgo: bool = False
"""whether to use UPGO for policy update"""
ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25
"""the PPO surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = None
"""the dual surrogate clipping coefficient"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
......@@ -122,11 +123,13 @@ class Args:
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
bfloat16: bool = False
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
......@@ -145,6 +148,7 @@ class Args:
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
......@@ -164,6 +168,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward if mode == 'self' else True,
play_mode=mode,
)
envs.num_envs = num_envs
......@@ -177,7 +182,6 @@ class Transition(NamedTuple):
logits: list
rewards: list
mains: list
next_dones: list
def create_agent(args, multi_step=False):
......@@ -189,6 +193,7 @@ def create_agent(args, multi_step=False):
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
freeze_id=args.freeze_id,
)
......@@ -209,6 +214,10 @@ def rollout(
learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
......@@ -222,7 +231,7 @@ def rollout(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
args.local_eval_episodes // 4, mode=eval_mode)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
......@@ -249,6 +258,17 @@ def rollout(
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
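A sketch of how this battle helper is driven during evaluation: both parameter sets score every observation, and `main` selects per env which policy's action is taken and whose recurrent state advances. The toy `get_logits` below is a stand-in for the real agent apply function; shapes and parameters are illustrative only:

import jax
import jax.numpy as jnp

def get_logits(params, inputs, done):
    # stand-in: the real code applies the LSTM agent; here rstate passes through
    rstate, obs = inputs
    return rstate, obs @ params  # logits of shape (B, num_actions)

def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
    next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
    next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
    logits = jnp.where(main[:, None], logits1, logits2)
    # only the acting player's recurrent state advances this step
    rstate1 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
    rstate2 = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
    return rstate1, rstate2, logits.argmax(axis=1)

B, D, A = 4, 8, 5
key = jax.random.PRNGKey(0)
obs = jax.random.normal(key, (B, D))
params1 = jax.random.normal(key, (D, A))
params2 = -params1
rstate = jnp.zeros((B, D))
main = jnp.array([True, False, True, False])
done = jnp.zeros(B, dtype=bool)
r1, r2, actions = get_action_battle(params1, params2, rstate, rstate, obs, main, done)
print(actions.shape)  # (4,)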
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
......@@ -281,7 +301,6 @@ def rollout(
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
@jax.jit
......@@ -312,7 +331,7 @@ def rollout(
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length):
for _ in range(args.num_steps):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
......@@ -340,7 +359,6 @@ def rollout(
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
......@@ -348,15 +366,6 @@ def rollout(
if not d:
continue
cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
......@@ -364,10 +373,8 @@ def rollout(
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
storage = []
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
......@@ -384,7 +391,7 @@ def rollout(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
learn_opponent = False
payload = (
global_step,
......@@ -403,10 +410,13 @@ def rollout(
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
......@@ -419,19 +429,28 @@ def rollout(
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
metric_name = "eval_return"
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
metric_name = "eval_win_rate"
if device_thread_id != 0:
eval_queue.put(eval_return)
eval_queue.put(eval_stat)
else:
eval_stats = []
eval_stats.append(eval_return)
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time
......@@ -461,8 +480,15 @@ if __name__ == "__main__":
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices()
global_devices = jax.devices()
......@@ -517,6 +543,13 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
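The block above prepends a mean "unknown" row and writes the pretrained card embeddings into the freshly initialized parameter tree. A minimal sketch of the same overwrite pattern on a toy flax module (the module layout and key names here are illustrative, not the real Encoder path):

import numpy as np
import jax
import flax
import flax.linen as nn

class Toy(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Embed(num_embeddings=5, features=3, name="Embed_0")(x)

model = Toy()
toy_params = model.init(jax.random.PRNGKey(0), np.zeros((1,), dtype=np.int32))

pretrained = np.random.rand(4, 3).astype(np.float32)     # 4 known ids
unknown = pretrained.mean(axis=0, keepdims=True)          # row 0 = "unknown" id
table = np.concatenate([unknown, pretrained], axis=0)     # shape (5, 3), matches init

toy_params = flax.core.unfreeze(toy_params)
toy_params["params"]["Embed_0"]["embedding"] = jax.device_put(table)
toy_params = flax.core.freeze(toy_params)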
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
......@@ -541,6 +574,13 @@ if __name__ == "__main__":
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
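The eval checkpoint is deserialized with flax.serialization, using the freshly initialized `params` as the structure template. A small round-trip sketch (the file name is hypothetical):

import flax

# save: params -> bytes -> file
with open("agent.flax_model", "wb") as f:
    f.write(flax.serialization.to_bytes(params))
# load: bytes are restored into the structure of an existing params template
with open("agent.flax_model", "rb") as f:
    eval_params = flax.serialization.from_bytes(params, f.read())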
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
......@@ -550,67 +590,54 @@ if __name__ == "__main__":
return logits, value.squeeze(-1)
def ppo_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch, actions, logits, rewards, mask, next_value):
params, rstate1, rstate2, obs, dones, mains,
actions, logits, rewards, mask, next_value, next_done):
# (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
mask = mask & (~dones)
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
inputs = (rstate1, rstate2, obs, dones, mains)
new_logits, new_values = get_logits_and_value(params, inputs)
new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map(
new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_logits, new_values, logits, actions, rewards, next_dones, switch, mask),
(new_logits, new_values, logits, actions, rewards, dones, mains, mask),
)
next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)
v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0)
discounts = (1.0 - next_dones) * args.gamma
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
# TODO: use switch to calculate the correct value
vtrace_fn = partial(
vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
vtrace_returns = jax.vmap(
vtrace_fn, in_axes=1, out_axes=1)(
v_tm1, v_t, rewards, discounts, ratio)
if args.upgo:
advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
rewards, v_t, discounts) - v_tm1
else:
advs = vtrace_returns.q_estimate - v_tm1
# TODO: TD(lambda) for multi-step
target_values, advantages = vtrace_2p0s(
next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
if args.ppo_clip:
pg_loss = jax.vmap(
partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
ratio, advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
pg_loss = clipped_surrogate_pg_loss(
ratios, advantages, args.clip_coef, args.dual_clip_coef)
else:
pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs
pg_loss = jax.vmap(
rlax.policy_gradient_loss, in_axes=1)(
new_logits, actions, pg_advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)
pg_loss = jnp.sum(pg_loss * mask)
v_loss = 0.5 * (vtrace_returns.errors ** 2)
v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask)
entropy_loss = distrax.Softmax(new_logits).entropy()
entropy_loss = jnp.sum(entropy_loss * mask)
ent_loss = entropy_loss(new_logits)
ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids
ent_loss = ent_loss / n_valids
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
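The approx_kl logged here is the low-variance (ratio - 1 - log ratio) estimator, averaged only over valid (non-done) steps. A tiny numeric sketch of that masked average with made-up ratios:

import jax.numpy as jnp

ratios = jnp.array([1.0, 1.2, 0.8, 2.0])
mask = jnp.array([1.0, 1.0, 1.0, 0.0])     # last step is past a done, excluded
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / mask.sum()
print(float(approx_kl))  # ~0.014: small when the policy has barely moved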
def single_device_update(
agent_state: TrainState,
......@@ -618,6 +645,7 @@ if __name__ == "__main__":
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
......@@ -627,20 +655,13 @@ if __name__ == "__main__":
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main, = [
jnp.concatenate(x) for x in [sharded_next_main]
next_main, next_done = [
jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]
]
# reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
......@@ -650,9 +671,7 @@ if __name__ == "__main__":
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
next_value = jnp.where(next_main, next_value, -next_value)
def convert_data(x: jnp.ndarray, num_steps):
if args.update_epochs > 1:
......@@ -666,10 +685,11 @@ if __name__ == "__main__":
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value, shuffled_next_done = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
shuffled_storage = jax.tree.map(
partial(convert_data, num_steps=num_steps), storage)
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
......@@ -687,13 +707,13 @@ if __name__ == "__main__":
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.mains,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
shuffled_next_done,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
......@@ -712,7 +732,7 @@ if __name__ == "__main__":
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(7,),
static_broadcasted_argnums=(8,),
)
params_queues = []
......@@ -727,7 +747,9 @@ if __name__ == "__main__":
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
......@@ -741,6 +763,7 @@ if __name__ == "__main__":
d_idx * args.num_actor_threads + thread_id,
),
).start()
params_queues[-1].put(device_params)
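Each actor thread gets a params queue (learner to actor: weights) and a rollout queue (actor to learner: data), both of size 1 so neither side runs more than one update ahead. A stripped-down sketch of that handshake with dummy payloads:

import queue
import threading

params_q, rollout_q = queue.Queue(maxsize=1), queue.Queue(maxsize=1)

def actor():
    for step in range(3):
        p = params_q.get()            # block until the learner publishes weights
        rollout_q.put((step, p))      # ship the collected batch back

def learner():
    params_q.put("params-0")          # seed the first rollout
    for step in range(3):
        batch = rollout_q.get()       # block until a rollout arrives
        if step < 2:
            params_q.put(f"params-{step + 1}")
        print("learned on", batch)

t = threading.Thread(target=actor)
t.start()
learner()
t.join()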
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
......@@ -790,8 +813,9 @@ if __name__ == "__main__":
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
f"{global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
......
......@@ -22,11 +22,12 @@ from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.utils import init_ygopro, load_embeddings
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
from ygoai.rl.jax.switch import truncated_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
......@@ -77,8 +78,6 @@ class Args:
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
......@@ -95,8 +94,10 @@ class Args:
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = None
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
"""the maximum KLD for the SPO policy, typically 0.02"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
......@@ -144,9 +145,10 @@ class Args:
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
......@@ -163,7 +165,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward,
greedy_reward=args.greedy_reward if not eval else True,
play_mode=mode,
)
envs.num_envs = num_envs
......@@ -189,6 +191,7 @@ def create_agent(args, multi_step=False):
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
freeze_id=args.freeze_id,
)
......@@ -226,7 +229,7 @@ def rollout(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode)
args.local_eval_episodes // 4, mode=eval_mode, eval=True)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
......@@ -296,7 +299,6 @@ def rollout(
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
@jax.jit
......@@ -327,7 +329,7 @@ def rollout(
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length):
for _ in range(args.num_steps):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
......@@ -379,10 +381,8 @@ def rollout(
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
storage = []
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
......@@ -418,10 +418,13 @@ def rollout(
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
print(
f"{time_now} SPS: {SPS}, update: {SPS_update}, "
f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
......@@ -436,14 +439,13 @@ def rollout(
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
metric_name = "eval_return"
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
metric_name = "eval_win_rate"
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_stat = np.array([eval_return, eval_win_rate])
if device_thread_id != 0:
eval_queue.put(eval_stat)
else:
......@@ -451,12 +453,14 @@ def rollout(
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
eval_stats = np.stack(eval_stats)
eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
if __name__ == "__main__":
......@@ -485,8 +489,15 @@ if __name__ == "__main__":
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
embedding_shape = embeddings.shape
args.num_embeddings = embedding_shape
args.freeze_id = True if args.freeze_id is None else args.freeze_id
else:
embeddings = None
embedding_shape = None
local_devices = jax.local_devices()
global_devices = jax.devices()
......@@ -541,6 +552,13 @@ if __name__ == "__main__":
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
if embeddings is not None:
unknown_embed = embeddings.mean(axis=0)
embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
params = flax.core.unfreeze(params)
params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
params = flax.core.freeze(params)
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
......@@ -587,66 +605,53 @@ if __name__ == "__main__":
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
mask = mask & (~dones)
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
values, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs)),
(jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
new_values_, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_values, rewards, next_dones, switch),
)
advantages, target_values = compute_gae_2p0s(
next_value, values, rewards, next_dones, switch,
args.gamma, args.gae_lambda)
if args.upgo:
advantages = advantages + upgo_advantage(
next_value, values, rewards, next_dones, switch, args.gamma)
advantages, target_values = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
target_values, advantages = truncated_gae_2p0s(
next_value, new_values_, rewards, next_dones, switch,
args.gamma, args.gae_lambda, args.upgo)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
probs = jax.nn.softmax(logits)
new_probs = jax.nn.softmax(new_logits)
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
pg_loss = simple_policy_loss(
ratios, logits, new_logits, advantages, args.spo_kld_max)
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = clipped_surrogate_pg_loss(
ratios, advantages, args.clip_coef, args.dual_clip_coef)
pg_loss = jnp.sum(pg_loss * mask)
v_loss = 0.5 * ((new_values - target_values) ** 2)
v_loss = mse_loss(new_values, target_values)
v_loss = jnp.sum(v_loss * mask)
entropy_loss = distrax.Softmax(new_logits).entropy()
entropy_loss = jnp.sum(entropy_loss * mask)
ent_loss = entropy_loss(new_logits)
ent_loss = jnp.sum(ent_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids
ent_loss = ent_loss / n_valids
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update(
agent_state: TrainState,
......@@ -702,7 +707,8 @@ if __name__ == "__main__":
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
......@@ -829,8 +835,9 @@ if __name__ == "__main__":
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
f"{global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
......
......@@ -2,340 +2,273 @@ from functools import partial
import jax
import jax.numpy as jnp
from typing import NamedTuple
class VTraceOutput(NamedTuple):
q_estimate: jnp.ndarray
errors: jnp.ndarray
def vtrace(
v_tm1,
v_t,
r_t,
discount_t,
rho_tm1,
lambda_=1.0,
c_clip_min: float = 0.001,
c_clip_max: float = 1.007,
rho_clip_min: float = 0.001,
rho_clip_max: float = 1.007,
stop_target_gradients: bool = True,
):
"""
Args:
v_tm1: values at time t-1.
v_t: values at time t.
r_t: reward at time t.
discount_t: discount at time t.
rho_tm1: importance sampling ratios at time t-1.
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
"""
# Clip importance sampling ratios.
lambda_ = jnp.ones_like(discount_t) * lambda_
c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# Compute the temporal difference errors.
td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# Work backwards computing the td-errors.
def _body(acc, xs):
td_error, discount, c = xs
acc = td_error + discount * c * acc
return acc, acc
_, errors = jax.lax.scan(
_body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# Return errors, maybe disabling gradient flow through bootstrap targets.
errors = jax.lax.select(
stop_target_gradients,
jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
errors)
targets_tm1 = errors + v_tm1
q_bootstrap = jnp.concatenate([
lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
v_t[-1:],
], axis=0)
q_estimate = r_t + discount_t * q_bootstrap
return VTraceOutput(q_estimate=q_estimate, errors=errors)
def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
clipped_objective = jnp.fmin(prob_ratios_t * adv_t, clipped_ratios_t * adv_t)
return -jnp.mean(clipped_objective * mask)
@partial(jax.jit, static_argnums=(5, 6))
def compute_gae_2p0s(
next_value, values, rewards, next_dones, switch, gamma, gae_lambda,
):
def body_fn(carry, inp):
boot_value, boot_done, next_value, lastgaelam = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
lastgaelam = jnp.where(switch, 0, lastgaelam)
import chex
import distrax
gamma_ = gamma * (1.0 - next_done)
delta = reward + gamma_ * next_value - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
return (boot_value, boot_done, cur_value, lastgaelam), lastgaelam
next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, lastgaelam
# class VTraceOutput(NamedTuple):
# q_estimate: jnp.ndarray
# errors: jnp.ndarray
_, advantages = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
target_values = advantages + values
return advantages, target_values
@partial(jax.jit, static_argnums=(5,))
def upgo_advantage(
next_value, values, rewards, next_dones, switch, gamma):
def body_fn(carry, inp):
boot_value, boot_done, next_value, next_q, last_return = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
next_q = jnp.where(switch, -boot_value * gamma, next_q)
last_return = jnp.where(switch, -boot_value, last_return)
gamma_ = gamma * (1.0 - next_done)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
carry = boot_value, boot_done, cur_value, next_q, last_return
return carry, last_return
next_done = next_dones[-1]
carry = next_value, next_done, next_value, next_value, next_value
_, returns = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
return returns - values
# def compute_gae_once(carry, inp, gamma, gae_lambda):
# v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2 = carry
# rho, cur_values, log_ratio, next_done, r_t, corr_r_t, main = inp
# v = jnp.where(main, v1, v2)
# next_values = jnp.where(main, next_values1, next_values2)
# reward = jnp.where(main, reward1, reward2)
# xi = jnp.where(main, xi1, xi2)
# p_t = c_t = jnp.minimum(1.0, rho * xi)
# sig_v = p_t * (r_t + reward * rho + next_values - cur_values)
# reg_r = jnp.log(p / p_reg)
# q = r_t + rho * (reward + v)
# q = -eta * + cur_values
# v = cur_values + sig_v + c_t * (v - next_values)
# v1 = jnp.where(main, v, v1)
# v2 = jnp.where(main, v2, v)
# next_values1 = jnp.where(main, cur_values, next_values1)
# next_values2 = jnp.where(main, next_values2, cur_values)
# reward1 = jnp.where(main, 0, r_t + rho * reward1)
# reward2 = jnp.where(main, r_t + rho * reward2, 0)
# xi1 = jnp.where(main, 1, rho * xi1)
# xi2 = jnp.where(main, rho * xi2, 1)
# learn1 = learn
# learn2 = ~learn
# factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
# reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
# reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
# real_done1 = next_done | ~done_used1
# nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
# lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
# real_done2 = next_done | ~done_used2
# nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
# lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
# done_used1 = jnp.where(
# next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
# done_used2 = jnp.where(
# next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
# delta1 = reward1 + gamma * nextvalues1 - curvalues
# delta2 = reward2 + gamma * nextvalues2 - curvalues
# lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
# lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
# advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
# nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
# nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
# lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
# lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
# carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# return carry, advantages
# @partial(jax.jit, static_argnums=(6, 7))
# def vtrace_rnad(
# next_value, next_done, values, rewards, dones, learns,
# gamma, gae_lambda,
# ):
# next_value1 = next_value
# next_value2 = -next_value1
# done_used1 = jnp.ones_like(next_done)
# done_used2 = jnp.ones_like(next_done)
# reward1 = jnp.zeros_like(next_value)
# reward2 = jnp.zeros_like(next_value)
# lastgaelam1 = jnp.zeros_like(next_value)
# lastgaelam2 = jnp.zeros_like(next_value)
# carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
# _, advantages = jax.lax.scan(
# partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
# carry, (dones[1:], values, rewards, learns), reverse=True
# )
# target_values = advantages + values
# return advantages, target_values
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - curvalues
delta2 = reward2 + gamma * nextvalues2 - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, advantages
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
# def vtrace(
# v_tm1,
# v_t,
# r_t,
# discount_t,
# rho_tm1,
# lambda_=1.0,
# c_clip_min: float = 0.001,
# c_clip_max: float = 1.007,
# rho_clip_min: float = 0.001,
# rho_clip_max: float = 1.007,
# stop_target_gradients: bool = True,
# ):
# """
# Args:
# v_tm1: values at time t-1.
# v_t: values at time t.
# r_t: reward at time t.
# discount_t: discount at time t.
# rho_tm1: importance sampling ratios at time t-1.
# lambda_: mixing parameter; a scalar or a vector for timesteps t.
# clip_rho_threshold: clip threshold for importance weights.
# stop_target_gradients: whether or not to apply stop gradient to targets.
# """
# # Clip importance sampling ratios.
# lambda_ = jnp.ones_like(discount_t) * lambda_
# c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
# clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# # Compute the temporal difference errors.
# td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# # Work backwards computing the td-errors.
# def _body(acc, xs):
# td_error, discount, c = xs
# acc = td_error + discount * c * acc
# return acc, acc
# _, errors = jax.lax.scan(
# _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# # Return errors, maybe disabling gradient flow through bootstrap targets.
# errors = jax.lax.select(
# stop_target_gradients,
# jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
# errors)
# targets_tm1 = errors + v_tm1
# q_bootstrap = jnp.concatenate([
# lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
# v_t[-1:],
# ], axis=0)
# q_estimate = r_t + discount_t * q_bootstrap
# return VTraceOutput(q_estimate=q_estimate, errors=errors)
def entropy_loss(logits):
return distrax.Softmax(logits=logits).entropy()
def mse_loss(y_true, y_pred):
return 0.5 * ((y_true - y_pred) ** 2)
def policy_gradient_loss(logits, actions, advantages):
chex.assert_type([logits, actions, advantages], [float, int, float])
advs = jax.lax.stop_gradient(advantages)
log_probs = distrax.Softmax(logits=logits).log_prob(actions)
pg_loss = -log_probs * advs
return pg_loss
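These helpers replace the earlier rlax/inline versions; each returns a per-element loss and leaves masking and normalization to the caller. A quick shape check on toy inputs, assuming the three functions above are in scope:

import jax.numpy as jnp

logits = jnp.array([[2.0, 0.0], [0.0, 2.0]])
actions = jnp.array([0, 0])
advantages = jnp.array([1.0, -1.0])

pg = policy_gradient_loss(logits, actions, advantages)       # -log pi(a) * adv, per step
ent = entropy_loss(logits)                                    # per-step policy entropy
v = mse_loss(jnp.array([0.5, 0.2]), jnp.array([1.0, 0.0]))    # 0.5 * squared error
print(pg.shape, ent.shape, v.shape)   # (2,) (2,) (2,)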
def clipped_surrogate_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None):
# dual clip from JueWu (Mastering Complex Control in MOBA Games with Deep Reinforcement Learning)
advs = jax.lax.stop_gradient(advantages)
clipped_ratios = jnp.clip(ratios, 1 - clip_coef, 1 + clip_coef)
clipped_obj = jnp.fmin(ratios * advs, clipped_ratios * advs)
if dual_clip_coef is not None:
clipped_obj = jnp.where(
advs >= 0, clipped_obj,
jnp.fmax(clipped_obj, dual_clip_coef * advs)
)
pg_loss = -clipped_obj
return pg_loss
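Dual clip only changes the objective when the advantage is negative and the ratio has blown up; it bounds the per-step loss at -(dual_clip_coef * adv) instead of letting the unclipped term dominate. A numeric check, assuming the function above is in scope:

import jax.numpy as jnp

ratios = jnp.array([5.0, 5.0])
advs = jnp.array([1.0, -1.0])
plain = clipped_surrogate_pg_loss(ratios, advs, clip_coef=0.25)
dual = clipped_surrogate_pg_loss(ratios, advs, clip_coef=0.25, dual_clip_coef=3.0)
print(plain)  # [-1.25, 5.0]: the negative-advantage term grows with the ratio
print(dual)   # [-1.25, 3.0]: dual clip caps it at -(3.0 * -1.0)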
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry
ratio, cur_values, next_done, r_t, main = inp
v1 = jnp.where(next_done, 0, v1)
v2 = jnp.where(next_done, 0, v2)
next_values1 = jnp.where(next_done, 0, next_values1)
next_values2 = jnp.where(next_done, 0, next_values2)
reward1 = jnp.where(next_done, 0, reward1)
reward2 = jnp.where(next_done, 0, reward2)
xi1 = jnp.where(next_done, 1, xi1)
xi2 = jnp.where(next_done, 1, xi2)
discount = gamma * (1.0 - next_done)
v = jnp.where(main, v1, v2)
next_values = jnp.where(main, next_values1, next_values2)
reward = jnp.where(main, reward1, reward2)
xi = jnp.where(main, xi1, xi2)
q_t = r_t + ratio * reward + discount * v
rho_t = jnp.clip(ratio * xi, rho_min, rho_max)
c_t = jnp.clip(ratio * xi, c_min, c_max)
sig_v = rho_t * (r_t + ratio * reward + discount * next_values - cur_values)
v = cur_values + sig_v + c_t * discount * (v - next_values)
# UPGO advantage (not corrected by importance sampling, unlike V-trace)
return_t = jnp.where(main, last_return1, last_return2)
next_q = jnp.where(main, next_q1, next_q2)
factor = jnp.where(main, jnp.ones_like(r_t), -jnp.ones_like(r_t))
return_t = r_t + discount * jnp.where(
next_q >= next_values, return_t, next_values)
last_return1 = jnp.where(
next_done, r_t * factor, jnp.where(main, return_t, last_return1))
last_return2 = jnp.where(
next_done, r_t * -factor, jnp.where(main, last_return2, return_t))
next_q = r_t + discount * next_values
next_q1 = jnp.where(
next_done, r_t * factor, jnp.where(main, next_q, next_q1))
next_q2 = jnp.where(
next_done, r_t * -factor, jnp.where(main, next_q2, next_q))
v1 = jnp.where(main, v, v1)
v2 = jnp.where(main, v2, v)
next_values1 = jnp.where(main, cur_values, next_values1)
next_values2 = jnp.where(main, next_values2, cur_values)
reward1 = jnp.where(main, 0, -r_t + ratio * reward1)
reward2 = jnp.where(main, -r_t + ratio * reward2, 0)
xi1 = jnp.where(main, 1, ratio * xi1)
xi2 = jnp.where(main, ratio * xi2, 1)
carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2
return carry, (v, q_t, return_t)
def vtrace_2p0s(
next_value, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
):
next_value1 = jnp.where(next_learn, next_value, -next_value)
next_value1 = next_value
next_value2 = -next_value1
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
_, advantages = jax.lax.scan(
partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
carry, (dones[1:], values, rewards, learns), reverse=True
v1 = return1 = next_q1 = next_value1
v2 = return2 = next_q2 = next_value2
reward1 = reward2 = jnp.zeros_like(next_value)
xi1 = xi2 = jnp.ones_like(next_value)
carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
return1, return2, next_q1, next_q2
_, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
target_values = advantages + values
return advantages, target_values
def compute_gae_once_upgo(carry, inp, gamma, gae_lambda):
next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
advantages = q_estimate - values
if upgo:
advantages += return_t - values
targets = jax.lax.stop_gradient(targets)
return targets, advantages
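A shape-level sketch of calling the two-player zero-sum V-trace above: inputs are time-major (T, B), `mains` marks at each step which envs the main player acted in, and `next_value` is the bootstrap for the state after the last step. The values below are dummies:

import jax.numpy as jnp

T, B = 4, 2
values = jnp.zeros((T, B))
ratios = jnp.ones((T, B))
rewards = jnp.zeros((T, B)).at[-1, :].set(1.0)              # terminal reward only
next_dones = jnp.zeros((T, B), dtype=bool).at[-1, :].set(True)
mains = jnp.array([[True, False]] * T)
next_value = jnp.zeros((B,))

targets, advantages = vtrace_2p0s(
    next_value, ratios, values, rewards, next_dones, mains, gamma=1.0)
print(targets.shape, advantages.shape)   # (4, 2) (4, 2)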
def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
cur_value, next_done, reward, main = inp
main1 = main
main2 = ~main
factor = jnp.where(main1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(main1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(main2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
next_value1 = jnp.where(real_done1, 0, next_value1)
last_return1 = jnp.where(real_done1, 0, last_return1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
next_value2 = jnp.where(real_done2, 0, next_value2)
last_return2 = jnp.where(real_done2, 0, last_return2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
next_done, main1, jnp.where(main1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
next_done, main2, jnp.where(main2 & ~done_used2, True, done_used2))
# UPGO advantage
last_return1 = jnp.where(real_done1, 0, last_return1)
last_return2 = jnp.where(real_done2, 0, last_return2)
last_return1_ = reward1 + gamma * jnp.where(
next_q1 >= next_value1, last_return1, next_value1)
last_return2_ = reward2 + gamma * jnp.where(
next_q2 >= next_value2, last_return2, next_value2)
next_q1_ = reward1 + gamma * next_value1
next_q2_ = reward2 + gamma * next_value2
delta1 = next_q1_ - curvalues
delta2 = next_q2_ - curvalues
next_q1 = jnp.where(main1, next_q1_, next_q1)
next_q2 = jnp.where(main2, next_q2_, next_q2)
last_return1 = jnp.where(main1, last_return1_, last_return1)
last_return2 = jnp.where(main2, last_return2_, last_return2)
returns = jnp.where(main1, last_return1_, last_return2_)
delta1 = next_q1_ - cur_value
delta2 = next_q2_ - cur_value
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
returns = jnp.where(learn1, last_return1_, last_return2_)
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
next_value1 = jnp.where(learn1, curvalues, next_value1)
next_value2 = jnp.where(learn2, curvalues, next_value2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
next_q1 = jnp.where(learn1, next_q1_, next_q1)
next_q2 = jnp.where(learn2, next_q2_, next_q1)
last_return1 = jnp.where(learn1, last_return1_, last_return1)
last_return2 = jnp.where(learn2, last_return2_, last_return2)
carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
advantages = jnp.where(main1, lastgaelam1_, lastgaelam2_)
next_value1 = jnp.where(main1, cur_value, next_value1)
next_value2 = jnp.where(main2, cur_value, next_value2)
lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)
carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
return carry, (advantages, returns)
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
def truncated_gae_2p0s(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
):
next_value1 = jnp.where(next_learn, next_value, -next_value)
next_value1 = next_value
next_value2 = -next_value1
last_return1 = next_q1 = next_value1
last_return2 = next_q2 = next_value2
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
done_used1 = jnp.ones_like(next_dones[-1])
done_used2 = jnp.ones_like(next_dones[-1])
reward1 = reward2 = jnp.zeros_like(next_value)
lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
_, (advantages, returns) = jax.lax.scan(
partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda),
carry, (dones[1:], values, rewards, learns), reverse=True
partial(truncated_gae_upgo_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True
)
return returns - values, advantages + values
if upgo:
advantages += returns - values
targets = values + advantages
targets = jax.lax.stop_gradient(targets)
return targets, advantages
def simple_policy_loss(ratios, logits, new_logits, advantages, kld_max, eps=1e-12):
advs = jax.lax.stop_gradient(advantages)
probs = jax.nn.softmax(logits)
new_probs = jax.nn.softmax(new_logits)
kld = jnp.sum(
probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, kld_max)
d_ratio = kld_clip / (kld + eps)
# e == 1 and t == 1
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advs)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advs * ratios * result
return pg_loss
\ No newline at end of file
......@@ -150,6 +150,7 @@ class Encoder(nn.Module):
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
freeze_id: bool = False
@nn.compact
def __call__(self, x):
......@@ -168,6 +169,8 @@ class Encoder(nn.Module):
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
if self.freeze_id:
# bind the embed module to a separate name so the lambda does not capture itself
embed_fn = id_embed
id_embed = lambda x: jax.lax.stop_gradient(embed_fn(x))
action_encoder = ActionEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
......@@ -337,6 +340,7 @@ class PPOLSTMAgent(nn.Module):
param_dtype: jnp.dtype = jnp.float32
multi_step: bool = False
switch: bool = True
freeze_id: bool = False
@nn.compact
def __call__(self, inputs):
......@@ -355,6 +359,7 @@ class PPOLSTMAgent(nn.Module):
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
freeze_id=self.freeze_id,
)
f_actions, f_state, mask, valid = encoder(x)
......
import jax
import jax.numpy as jnp
def truncated_gae_2p0s(
next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo
):
def body_fn(carry, inp):
boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
lastgaelam = jnp.where(switch, 0, lastgaelam)
next_q = jnp.where(switch, -boot_value * gamma, next_q)
last_return = jnp.where(switch, -boot_value, last_return)
discount = gamma * (1.0 - next_done)
last_return = reward + discount * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + discount * next_value
delta = next_q - cur_value
lastgaelam = delta + gae_lambda * discount * lastgaelam
carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return
return carry, (lastgaelam, last_return)
next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value)
next_q = last_return = next_value
carry = next_value, next_done, next_value, lastgaelam, next_q, last_return
_, (advantages, returns) = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
if upgo:
advantages += returns - values
targets = values + advantages
targets = jax.lax.stop_gradient(targets)
return targets, advantages
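Usage sketch for the switch-based variant above: `switch` is True at the step where the trajectory flips to the other player's perspective (the bootstrap value is negated there). Shapes and values below are dummies only:

import jax.numpy as jnp

T, B = 6, 2
values = jnp.zeros((T, B))
rewards = jnp.zeros((T, B))
next_dones = jnp.zeros((T, B), dtype=bool).at[-1, :].set(True)
switch_steps = jnp.array([3, 4])                        # per-env step where the player flips
switch = jnp.arange(T)[:, None] == (switch_steps[None, :] - 1)
next_value = jnp.zeros((B,))

targets, advantages = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch,
    gamma=1.0, gae_lambda=0.95, upgo=False)
print(targets.shape, advantages.shape)   # (6, 2) (6, 2)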
......@@ -58,13 +58,3 @@ def masked_normalize(x, valid, eps=1e-8):
def to_tensor(x, device, dtype=None):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
def load_embeddings(embedding_file, code_list_file):
with open(embedding_file, "rb") as f:
embeddings = pickle.load(f)
with open(code_list_file, "r") as f:
code_list = f.readlines()
code_list = [int(code.strip()) for code in code_list]
assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
return embeddings
import pickle
import numpy as np
from pathlib import Path
......@@ -43,4 +45,15 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
elif 'EDOPro' in env_id:
from ygoenv.edopro import init_module
init_module(str(db_path), code_list_file, decks)
return deck_name
\ No newline at end of file
return deck_name
def load_embeddings(embedding_file, code_list_file):
with open(embedding_file, "rb") as f:
embeddings = pickle.load(f)
with open(code_list_file, "r") as f:
code_list = f.readlines()
code_list = [int(code.strip()) for code in code_list]
assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
return embeddings
\ No newline at end of file
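load_embeddings expects a pickled dict mapping card codes to vectors plus a text file listing the codes one per line; it returns the vectors stacked in code-list order. A small round-trip sketch with synthetic files (paths are illustrative):

import pickle
import numpy as np

codes = [10000, 10001, 10002]
table = {c: np.random.rand(8).astype(np.float32) for c in codes}
with open("/tmp/embed.pkl", "wb") as f:
    pickle.dump(table, f)
with open("/tmp/code_list.txt", "w") as f:
    f.write("\n".join(str(c) for c in codes) + "\n")

embeddings = load_embeddings("/tmp/embed.pkl", "/tmp/code_list.txt")
print(embeddings.shape)   # (3, 8)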