from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, ach_loss
from ygoai.rl.jax.switch import truncated_gae_2p0s
......@@ -98,6 +98,8 @@ class Args:
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy, typically 0.02"""
logits_threshold: Optional[float] = None
"""the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
......@@ -635,6 +637,9 @@ if __name__ == "__main__":
if args.spo_kld_max is not None:
pg_loss = simple_policy_loss(
ratios, logits, new_logits, advantages, args.spo_kld_max)
elif args.logits_threshold is not None:
pg_loss = ach_loss(
actions, logits, new_logits, advantages, args.logits_threshold, args.clip_coef, args.dual_clip_coef)
pg_loss = clipped_surrogate_pg_loss(
ratios, advantages, args.clip_coef, args.dual_clip_coef)
......@@ -100,6 +100,64 @@ def clipped_surrogate_pg_loss(ratios, advantages, clip_coef, dual_clip_coef=None
return pg_loss
def get_from_action(values, action):
    """Pick, along the last axis of ``values``, the entry indexed by ``action``.

    The lookup is expressed as a one-hot selection: ``action`` is one-hot
    encoded over ``values.shape[-1]`` categories (in ``values.dtype``),
    multiplied into ``values`` with ``distrax.multiply_no_nan`` and summed
    over the last axis.

    Args:
        values: array whose last axis enumerates the categories.
        action: integer indices into the last axis of ``values``.

    Returns:
        ``values`` reduced over its last axis to the selected entries.
    """
    depth = values.shape[-1]
    selector = jax.nn.one_hot(action, depth, dtype=values.dtype)
    # NOTE(review): multiply_no_nan is presumably used so that non-finite
    # entries (e.g. masked logits) in unselected slots cannot leak into the sum.
    picked = distrax.multiply_no_nan(values, selector)
    return jnp.sum(picked, axis=-1)
def mean_legal(values, axis=None, mask=None):
    """Return the mean of the legal entries of ``values`` along ``axis``.

    Args:
        values: array of scores.
        axis: axis (or axes) to reduce over; ``None`` reduces over everything.
        mask: optional array broadcastable to ``values.shape``; truthy entries
            mark the positions that take part in the mean. When omitted, the
            previous heuristic is used: every entry greater than ``-1e12`` is
            considered legal.

    Returns:
        The sum of the legal entries divided by their count. The count is
        floored at 1, so a slice without any legal entry yields 0 instead of
        a division by zero.
    """
    if mask is None:
        # Fallback heuristic: illegal entries are assumed to have been pushed
        # to a very large negative value by the caller.
        # TODO: pass the real action mask from the call sites.
        legal = values > -1e12
    else:
        # Broadcast explicitly so the count below has one entry per value,
        # even when the caller supplies a lower-rank mask.
        legal = jnp.broadcast_to(jnp.asarray(mask, dtype=bool), values.shape)

    # Zero out illegal entries so they do not contribute to the sum.
    kept = jnp.where(legal, values, 0)
    count = jnp.sum(legal, axis=axis)
    return jnp.sum(kept, axis=axis) / jnp.maximum(count, 1)
def neurd_loss(actions, logits, new_logits, advantages, logits_threshold):
    """Per-sample Neural Replicator Dynamics (NeuRD) policy loss.

    Differences from the original implementation:
    - all actions vs. sampled actions (only the taken ``actions`` are used here)
    - original computes advantages with q values
    - original does not use importance sampling ratios

    Args:
        actions: indices of the actions that were taken.
        logits: logits the actions were sampled from; only their softmax
            probability of the taken action is used, as a divisor.
        new_logits: logits of the policy being optimised.
        advantages: advantage estimates; gradients through them are stopped.
        logits_threshold: bound on the mean-centred logit of the taken action
            beyond which the sample no longer contributes.

    Returns:
        The un-reduced loss, one value per sample.
    """
    advs = jax.lax.stop_gradient(advantages)

    # Probability of the taken action under ``logits``, floored at 0.001 so
    # the division below stays bounded.
    behavior_prob = jnp.maximum(
        get_from_action(jax.nn.softmax(logits), actions), 0.001)

    # Logit of the taken action, centred on the mean over legal actions.
    centered_logit = (
        get_from_action(new_logits, actions) - mean_legal(new_logits, axis=-1))

    # Positive advantages only push while the centred logit is below the
    # threshold; negative ones only while it is above the negated threshold.
    below_upper = centered_logit < logits_threshold
    above_lower = centered_logit > -logits_threshold
    gate = jax.lax.stop_gradient(
        jnp.where(advs >= 0, below_upper, above_lower).astype(jnp.float32))

    return -gate * centered_logit / behavior_prob * advs
def ach_loss(actions, logits, new_logits, advantages, logits_threshold, clip_coef, dual_clip_coef=None):
    """Per-sample Actor-Critic Hedge (ACH) policy loss.

    From "Actor-Critic Policy Optimization in a Large-Scale
    Imperfect-Information Game". Notice that the entropy term is required by
    the method but is not included here.

    Args:
        actions: indices of the actions that were taken.
        logits: logits the actions were sampled from.
        new_logits: logits of the policy being optimised.
        advantages: advantage estimates; gradients through them are stopped.
        logits_threshold: bound on the mean-centred logit of the taken action.
        clip_coef: clipping range applied to the importance sampling ratio.
        dual_clip_coef: optional upper bound on the ratio that is additionally
            enforced for samples with negative advantage.

    Returns:
        The un-reduced loss, one value per sample.
    """
    advs = jax.lax.stop_gradient(advantages)

    # Probability of the taken action under ``logits``, floored at 0.001 so
    # the division below stays bounded.
    behavior_prob = jnp.maximum(
        get_from_action(jax.nn.softmax(logits), actions), 0.001)

    # Logit of the taken action, centred on the mean over legal actions.
    centered_logit = (
        get_from_action(new_logits, actions) - mean_legal(new_logits, axis=-1))

    # The ratio only feeds the boolean gates below, never the loss value.
    ratios = distrax.importance_sampling_ratios(
        distrax.Categorical(new_logits), distrax.Categorical(logits), actions)

    # Gate used for non-negative advantages.
    gate_pos = (ratios < 1 + clip_coef) & (centered_logit < logits_threshold)
    # Gate used for negative advantages.
    gate_neg = (ratios > 1 - clip_coef) & (centered_logit > -logits_threshold)
    if dual_clip_coef is not None:
        gate_neg = gate_neg & (ratios < dual_clip_coef)

    gate = jax.lax.stop_gradient(
        jnp.where(advs >= 0, gate_pos, gate_neg).astype(jnp.float32))

    return -gate * centered_logit / behavior_prob * advs
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry
