Commit 831e92ff authored by sbl1996@126.com

Add sep_value

parent e03d45b6
@@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone
 from collections import deque
 from dataclasses import dataclass, field, asdict
 from types import SimpleNamespace
-from typing import List, NamedTuple, Optional
+from typing import List, NamedTuple, Optional, Literal
 from functools import partial

 import ygoenv
@@ -28,7 +28,8 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
 from ygoai.rl.jax.agent import RNNAgent, ModelArgs
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
+from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
+    ach_loss, policy_gradient_loss, vtrace, vtrace_2p0s, truncated_gae
 from ygoai.rl.jax.switch import truncated_gae_2p0s as gae_2p0s_switch
@@ -116,6 +117,10 @@ class Args:
     upgo: bool = True
     """Toggle the use of UPGO for advantages"""
+    sep_value: bool = True
+    """Whether to compute a separate value function for each player"""
+    value: Literal["vtrace", "gae"] = "vtrace"
+    """the method to learn the value function"""
     gae_lambda: float = 0.95
     """the lambda for the general advantage estimation"""
     c_clip_min: float = 0.001
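
With the dispatch added to main() below, the two new fields combine as follows (summarized from this diff):

- value="vtrace" with sep_value=False selects vtrace_2p0s, the previous two-player zero-sum estimator.
- value="vtrace" with sep_value=True selects the new vtrace, which bootstraps a single per-player value and flips its sign whenever the acting player changes.
- value="gae" requires sep_value=True and selects the new truncated_gae (with optional UPGO); value="gae" with sep_value=False raises NotImplementedError.
- args.switch currently supports only value="gae" with sep_value=False; other combinations raise NotImplementedError.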
@@ -738,15 +743,25 @@ def main():
         # Advantages and target values
         if args.switch:
+            if args.value == "vtrace" or args.sep_value:
+                raise NotImplementedError
             target_values, advantages = gae_2p0s_switch(
                 next_value, new_values_, rewards, next_dones, switch_or_mains,
                 args.gamma, args.gae_lambda, args.upgo)
         else:
             # TODO: TD(lambda) for multi-step
             ratios_ = reshape_time_series(ratios)
-            target_values, advantages = vtrace_2p0s(
-                next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
-                args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
+            if args.value == "gae":
+                if not args.sep_value:
+                    raise NotImplementedError
+                target_values, advantages = truncated_gae(
+                    next_value, new_values_, rewards, next_dones, switch_or_mains,
+                    args.gamma, args.gae_lambda, args.upgo)
+            else:
+                vtrace_fn = vtrace if args.sep_value else vtrace_2p0s
+                target_values, advantages = vtrace_fn(
+                    next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
+                    args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
         target_values, advantages = jax.tree.map(
             lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
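
For the shapes involved: rollout buffers such as ratios and values are flat [T*B] arrays; reshape_time_series (defined elsewhere in this script; its exact definition is assumed here) lifts them to time-major [T, B] for the backward scans, and the jax.tree.map call above flattens targets and advantages back. A minimal sketch of that round trip with hypothetical sizes:

import jax
import jax.numpy as jnp

T, B = 128, 16                                   # hypothetical rollout length and batch size
flat = jnp.arange(T * B, dtype=jnp.float32)      # flat rollout buffer, shape [T*B]
series = flat.reshape(T, B)                      # time-major view consumed by the estimators
targets, advantages = series, series             # stand-ins for estimator outputs
targets, advantages = jax.tree.map(
    lambda x: jnp.reshape(x, (-1,)), (targets, advantages))  # back to [T*B], as above
assert targets.shape == (T * B,)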
...
@@ -7,68 +7,6 @@ import chex
 import distrax

-# class VTraceOutput(NamedTuple):
-#     q_estimate: jnp.ndarray
-#     errors: jnp.ndarray
-
-
-# def vtrace(
-#     v_tm1,
-#     v_t,
-#     r_t,
-#     discount_t,
-#     rho_tm1,
-#     lambda_=1.0,
-#     c_clip_min: float = 0.001,
-#     c_clip_max: float = 1.007,
-#     rho_clip_min: float = 0.001,
-#     rho_clip_max: float = 1.007,
-#     stop_target_gradients: bool = True,
-# ):
-#     """
-#     Args:
-#         v_tm1: values at time t-1.
-#         v_t: values at time t.
-#         r_t: reward at time t.
-#         discount_t: discount at time t.
-#         rho_tm1: importance sampling ratios at time t-1.
-#         lambda_: mixing parameter; a scalar or a vector for timesteps t.
-#         clip_rho_threshold: clip threshold for importance weights.
-#         stop_target_gradients: whether or not to apply stop gradient to targets.
-#     """
-#     # Clip importance sampling ratios.
-#     lambda_ = jnp.ones_like(discount_t) * lambda_
-#     c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
-#     clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
-
-#     # Compute the temporal difference errors.
-#     td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
-
-#     # Work backwards computing the td-errors.
-#     def _body(acc, xs):
-#         td_error, discount, c = xs
-#         acc = td_error + discount * c * acc
-#         return acc, acc
-
-#     _, errors = jax.lax.scan(
-#         _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
-
-#     # Return errors, maybe disabling gradient flow through bootstrap targets.
-#     errors = jax.lax.select(
-#         stop_target_gradients,
-#         jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
-#         errors)
-#     targets_tm1 = errors + v_tm1
-#     q_bootstrap = jnp.concatenate([
-#         lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
-#         v_t[-1:],
-#     ], axis=0)
-#     q_estimate = r_t + discount_t * q_bootstrap
-#     return VTraceOutput(q_estimate=q_estimate, errors=errors)


 def entropy_loss(logits):
     return distrax.Softmax(logits=logits).entropy()
@@ -255,6 +193,57 @@ def vtrace_rnad(
     return targets, q_estimate


+def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
+    v, next_value, last_return, next_q, next_main = carry
+    ratio, cur_value, next_done, reward, main = inp
+
+    v = jnp.where(next_done, 0, v)
+    next_value = jnp.where(next_done, 0, next_value)
+    sign = jnp.where(main == next_main, 1, -1)
+    v = v * sign
+    next_value = next_value * sign
+
+    discount = gamma * (1.0 - next_done)
+    q_t = reward + discount * v
+
+    rho_t = jnp.clip(ratio, rho_min, rho_max)
+    c_t = jnp.clip(ratio, c_min, c_max)
+    sig_v = rho_t * (reward + discount * next_value - cur_value)
+    v = cur_value + sig_v + c_t * discount * (v - next_value)
+
+    # UPGO advantage (not corrected by importance sampling, unlike V-trace)
+    last_return = last_return * sign
+    next_q = next_q * sign
+    last_return = reward + discount * jnp.where(
+        next_q >= next_value, last_return, next_value)
+    next_q = reward + discount * next_value
+
+    carry = v, cur_value, last_return, next_q, main
+    return carry, (v, q_t, last_return)
+
+
+def vtrace(
+    next_value, ratios, values, rewards, next_dones, mains,
+    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+):
+    v = last_return = next_q = next_value
+    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+    carry = v, next_value, last_return, next_q, next_main
+
+    _, (targets, q_estimate, return_t) = jax.lax.scan(
+        partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
+        carry, (ratios, values, next_dones, rewards, mains), reverse=True
+    )
+
+    advantages = q_estimate - values
+    if upgo:
+        advantages += return_t - values
+    targets = jax.lax.stop_gradient(targets)
+    return targets, advantages


 def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
         last_return1, last_return2, next_q1, next_q2 = carry
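
The new vtrace implements the standard single-stream V-trace recursion, v_s = V(x_s) + rho_s (r_s + gamma V(x_{s+1}) - V(x_s)) + c_s gamma (v_{s+1} - V(x_{s+1})) with clipped importance weights rho_s and c_s, except that bootstrapped quantities are reset at episode boundaries and negated whenever the acting player changes between steps (the game is zero-sum, so one player's value is the negative of the other's). A minimal smoke test, assuming vtrace is re-exported from ygoai.rl.jax as the import hunk above indicates; all sizes and values are hypothetical:

import jax.numpy as jnp
from ygoai.rl.jax import vtrace  # added by this commit

T, B = 4, 2                                       # hypothetical time and batch sizes
next_value = jnp.zeros((B,))                      # bootstrap value after the last step
ratios = jnp.ones((T, B))                         # pi/mu importance ratios (on-policy here)
values = jnp.zeros((T, B))
rewards = jnp.zeros((T, B)).at[-1].set(1.0)       # terminal reward only
next_dones = jnp.zeros((T, B), dtype=jnp.bool_).at[-1].set(True)
mains = jnp.array([[True, False]] * T)            # which player acted at each step
targets, advantages = vtrace(
    next_value, ratios, values, rewards, next_dones, mains,
    gamma=0.99, upgo=True)                        # both outputs have shape [T, B]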
@@ -412,24 +401,46 @@ def truncated_gae_2p0s(


 def truncated_gae_loop(carry, inp, gamma, gae_lambda):
-    lastgaelam, next_value = carry
-    cur_value, next_done, reward = inp
-    nextnonterminal = 1.0 - next_done
-    delta = reward + gamma * next_value * nextnonterminal - cur_value
-    lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
-    carry = lastgaelam, cur_value
-    return carry, lastgaelam
+    lastgaelam, next_value, last_return, next_q, next_main = carry
+    cur_value, next_done, reward, main = inp
+
+    lastgaelam = jnp.where(next_done, 0, lastgaelam)
+    next_value = jnp.where(next_done, 0, next_value)
+    sign = jnp.where(main == next_main, 1, -1)
+    lastgaelam = lastgaelam * sign
+    next_value = next_value * sign
+
+    discount = gamma * (1.0 - next_done)
+    delta = reward + discount * next_value - cur_value
+    lastgaelam = delta + discount * gae_lambda * lastgaelam
+
+    # UPGO advantage
+    last_return = last_return * sign
+    next_q = next_q * sign
+    last_return = reward + discount * jnp.where(
+        next_q >= next_value, last_return, next_value)
+    next_q = reward + discount * next_value
+
+    carry = lastgaelam, cur_value, last_return, next_q, main
+    return carry, (lastgaelam, last_return)


-def truncated_gae(next_value, values, rewards, next_dones, gamma, gae_lambda):
+def truncated_gae(
+        next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False):
     lastgaelam = jnp.zeros_like(next_value)
-    carry = lastgaelam, next_value
-    _, advantages = jax.lax.scan(
+    last_return = next_q = next_value
+    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+    carry = lastgaelam, next_value, last_return, next_q, next_main
+    _, (advantages, return_t) = jax.lax.scan(
         partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
-        carry, (values, next_dones, rewards), reverse=True
+        carry, (values, next_dones, rewards, mains), reverse=True
     )
     targets = values + advantages
+    if upgo:
+        advantages += return_t - values
     targets = jax.lax.stop_gradient(targets)
     return targets, advantages
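
truncated_gae now threads the same mains mask through the scan as vtrace: the running lastgaelam and the bootstrapped next_value are zeroed at episode boundaries and sign-flipped when the turn passes to the other player, and upgo=True adds the UPGO term to the advantages. A hedged usage sketch on a toy alternating-turn trace (all values hypothetical):

import jax.numpy as jnp
from ygoai.rl.jax import truncated_gae  # new signature from this commit

T, B = 3, 1
values = jnp.zeros((T, B))                         # zero value estimates for illustration
rewards = jnp.array([[0.0], [0.0], [1.0]])         # win for the player acting last
next_dones = jnp.array([[False], [False], [True]])
mains = jnp.array([[True], [False], [True]])       # players alternate turns
next_value = jnp.zeros((B,))
targets, advantages = truncated_gae(
    next_value, values, rewards, next_dones, mains,
    gamma=1.0, gae_lambda=1.0, upgo=True)
# With zero value estimates, targets are the sign-corrected Monte Carlo
# returns: +1 for the winner's steps, -1 for the loser's.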
...