Commit 831e92ff authored by sbl1996@126.com

Add sep_value

parent e03d45b6
@@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field, asdict
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from typing import List, NamedTuple, Optional, Literal
from functools import partial
import ygoenv
@@ -28,7 +28,8 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
ach_loss, policy_gradient_loss, vtrace, vtrace_2p0s, truncated_gae
from ygoai.rl.jax.switch import truncated_gae_2p0s as gae_2p0s_switch
@@ -116,6 +117,10 @@ class Args:
upgo: bool = True
"""Toggle the use of UPGO for advantages"""
sep_value: bool = True
"""Whether separate value function computation for each player"""
value: Literal["vtrace", "gae"] = "vtrace"
"""the method to learn the value function"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
c_clip_min: float = 0.001
@@ -738,13 +743,23 @@ def main():
# Advantages and target values
if args.switch:
if args.value == "vtrace" or args.sep_value:
raise NotImplementedError
target_values, advantages = gae_2p0s_switch(
next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo)
else:
# TODO: TD(lambda) for multi-step
ratios_ = reshape_time_series(ratios)
target_values, advantages = vtrace_2p0s(
if args.value == "gae":
if not args.sep_value:
raise NotImplementedError
target_values, advantages = truncated_gae(
next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo)
else:
vtrace_fn = vtrace if args.sep_value else vtrace_2p0s
target_values, advantages = vtrace_fn(
next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
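
Taken together, the new value and sep_value options select one of four estimator paths. The helper below is only an illustrative summary of the branch above, not code from this commit; the name select_estimator is hypothetical, and the returned strings name the estimators imported from ygoai.rl.jax and ygoai.rl.jax.switch:

def select_estimator(switch: bool, value: str, sep_value: bool) -> str:
    # Illustrative sketch mirroring the dispatch added to main() above.
    if switch:
        if value == "vtrace" or sep_value:
            raise NotImplementedError  # switch mode still supports only the shared-value GAE path
        return "gae_2p0s_switch"       # truncated_gae_2p0s from ygoai.rl.jax.switch
    if value == "gae":
        if not sep_value:
            raise NotImplementedError  # the GAE path requires separate per-player values
        return "truncated_gae"
    return "vtrace" if sep_value else "vtrace_2p0s"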
@@ -7,68 +7,6 @@ import chex
import distrax
# class VTraceOutput(NamedTuple):
# q_estimate: jnp.ndarray
# errors: jnp.ndarray
# def vtrace(
# v_tm1,
# v_t,
# r_t,
# discount_t,
# rho_tm1,
# lambda_=1.0,
# c_clip_min: float = 0.001,
# c_clip_max: float = 1.007,
# rho_clip_min: float = 0.001,
# rho_clip_max: float = 1.007,
# stop_target_gradients: bool = True,
# ):
# """
# Args:
# v_tm1: values at time t-1.
# v_t: values at time t.
# r_t: reward at time t.
# discount_t: discount at time t.
# rho_tm1: importance sampling ratios at time t-1.
# lambda_: mixing parameter; a scalar or a vector for timesteps t.
# clip_rho_threshold: clip threshold for importance weights.
# stop_target_gradients: whether or not to apply stop gradient to targets.
# """
# # Clip importance sampling ratios.
# lambda_ = jnp.ones_like(discount_t) * lambda_
# c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
# clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# # Compute the temporal difference errors.
# td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# # Work backwards computing the td-errors.
# def _body(acc, xs):
# td_error, discount, c = xs
# acc = td_error + discount * c * acc
# return acc, acc
# _, errors = jax.lax.scan(
# _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# # Return errors, maybe disabling gradient flow through bootstrap targets.
# errors = jax.lax.select(
# stop_target_gradients,
# jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
# errors)
# targets_tm1 = errors + v_tm1
# q_bootstrap = jnp.concatenate([
# lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
# v_t[-1:],
# ], axis=0)
# q_estimate = r_t + discount_t * q_bootstrap
# return VTraceOutput(q_estimate=q_estimate, errors=errors)
def entropy_loss(logits):
return distrax.Softmax(logits=logits).entropy()
@@ -255,6 +193,57 @@ def vtrace_rnad(
return targets, q_estimate
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v, next_value, last_return, next_q, next_main = carry
ratio, cur_value, next_done, reward, main = inp
v = jnp.where(next_done, 0, v)
next_value = jnp.where(next_done, 0, next_value)
sign = jnp.where(main == next_main, 1, -1)
v = v * sign
next_value = next_value * sign
discount = gamma * (1.0 - next_done)
q_t = reward + discount * v
rho_t = jnp.clip(ratio, rho_min, rho_max)
c_t = jnp.clip(ratio, c_min, c_max)
sig_v = rho_t * (reward + discount * next_value - cur_value)
v = cur_value + sig_v + c_t * discount * (v - next_value)
# UPGO advantage (not corrected by importance sampling, unlike V-trace)
last_return = last_return * sign
next_q = next_q * sign
last_return = reward + discount * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + discount * next_value
carry = v, cur_value, last_return, next_q, main
return carry, (v, q_t, last_return)
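
For reference, the per-step update in vtrace_loop is the standard V-trace recursion (Espeholt et al., 2018), evaluated backwards over time by jax.lax.scan. Writing \gamma_t = \gamma (1 - d_{t+1}) for the masked discount and \rho_t, c_t for the clipped importance ratios, each step computes

    \delta_t = \rho_t \,\big(r_t + \gamma_t V(x_{t+1}) - V(x_t)\big)
    v_t = V(x_t) + \delta_t + \gamma_t c_t \,\big(v_{t+1} - V(x_{t+1})\big)
    A_t = r_t + \gamma_t v_{t+1} - V(x_t)

and, for the UPGO return carried alongside it,

    G^U_t = r_t + \gamma_t \cdot \big(G^U_{t+1} \text{ if } q_{t+1} \ge V(x_{t+1}) \text{ else } V(x_{t+1})\big), \qquad q_{t+1} = r_{t+1} + \gamma_{t+1} V(x_{t+2}),

so with upgo=True the advantage additionally receives G^U_t - V(x_t). The two-player twist is the sign variable: whenever main_t differs from main_{t+1}, the bootstrapped quantities (v_{t+1}, V(x_{t+1}), G^U_{t+1}, q_{t+1}) are negated first, so values are always expressed from the perspective of the player acting at step t.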
def vtrace(
next_value, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
):
v = last_return = next_q = next_value
next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
carry = v, next_value, last_return, next_q, next_main
_, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
advantages = q_estimate - values
if upgo:
advantages += return_t - values
targets = jax.lax.stop_gradient(targets)
return targets, advantages
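
A minimal shape sketch of the new vtrace entry point, assuming time-major (T, B) arrays as consumed by the scan over the leading axis; the toy numbers are purely illustrative, not taken from the training script:

import jax.numpy as jnp
from ygoai.rl.jax import vtrace  # per the import added to the training script above

T, B = 8, 4                                      # toy rollout: T steps, B parallel envs
next_value = jnp.zeros((B,))                     # bootstrap value after the last step
ratios     = jnp.ones((T, B))                    # on-policy rollout => importance ratios of 1
values     = jnp.zeros((T, B))
rewards    = jnp.zeros((T, B)).at[-1].set(1.0)   # sparse terminal reward
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1].set(True)
mains      = jnp.ones((T, B), dtype=jnp.bool_)   # single-perspective episode, no sign flips

targets, advantages = vtrace(
    next_value, ratios, values, rewards, next_dones, mains,
    gamma=0.99, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0)
assert targets.shape == (T, B) and advantages.shape == (T, B)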
def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry
@@ -412,24 +401,46 @@ def truncated_gae_2p0s(
def truncated_gae_loop(carry, inp, gamma, gae_lambda):
lastgaelam, next_value = carry
cur_value, next_done, reward = inp
nextnonterminal = 1.0 - next_done
lastgaelam, next_value, last_return, next_q, next_main = carry
cur_value, next_done, reward, main = inp
lastgaelam = jnp.where(next_done, 0, lastgaelam)
next_value = jnp.where(next_done, 0, next_value)
delta = reward + gamma * next_value * nextnonterminal - cur_value
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
carry = lastgaelam, cur_value
return carry, lastgaelam
sign = jnp.where(main == next_main, 1, -1)
lastgaelam = lastgaelam * sign
next_value = next_value * sign
discount = gamma * (1.0 - next_done)
delta = reward + discount * next_value - cur_value
lastgaelam = delta + discount * gae_lambda * lastgaelam
# UPGO advantage
last_return = last_return * sign
next_q = next_q * sign
last_return = reward + discount * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + discount * next_value
def truncated_gae(next_value, values, rewards, next_dones, gamma, gae_lambda):
carry = lastgaelam, cur_value, last_return, next_q, main
return carry, (lastgaelam, last_return)
def truncated_gae(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False):
lastgaelam = jnp.zeros_like(next_value)
carry = lastgaelam, next_value
_, advantages = jax.lax.scan(
last_return = next_q = next_value
next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
carry = lastgaelam, next_value, last_return, next_q, next_main
_, (advantages, return_t) = jax.lax.scan(
partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards), reverse=True
carry, (values, next_dones, rewards, mains), reverse=True
)
targets = values + advantages
if upgo:
advantages += return_t - values
targets = jax.lax.stop_gradient(targets)
return targets, advantages
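
Analogously, a small sketch of the extended truncated_gae with alternating perspectives; mains flips every step here, so the scan negates the bootstrapped value and UPGO return whenever the turn passes to the opponent. Shapes and values are toy assumptions, not taken from the training script:

import jax.numpy as jnp
from ygoai.rl.jax import truncated_gae  # per the import added to the training script above

T, B = 6, 2
next_value = jnp.zeros((B,))
values     = jnp.zeros((T, B))
rewards    = jnp.zeros((T, B)).at[-1].set(1.0)   # +1 credited to whichever player moved last
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1].set(True)
mains      = jnp.broadcast_to((jnp.arange(T) % 2 == 0)[:, None], (T, B))  # players alternate turns

targets, advantages = truncated_gae(
    next_value, values, rewards, next_dones, mains,
    gamma=1.0, gae_lambda=0.95, upgo=True)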