Commit 974fe861 authored by sbl1996@126.com's avatar sbl1996@126.com

Rename

parent dab0733b
......@@ -28,9 +28,9 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
ach_loss, policy_gradient_loss, vtrace, vtrace_2p0s, truncated_gae
from ygoai.rl.jax.switch import truncated_gae_2p0s as gae_2p0s_switch
ach_loss, policy_gradient_loss, vtrace, vtrace_sep, truncated_gae, truncated_gae_sep
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
......@@ -745,21 +745,20 @@ def main():
if args.switch:
if args.value == "vtrace" or args.sep_value:
raise NotImplementedError
target_values, advantages = gae_2p0s_switch(
target_values, advantages = gae_sep_switch(
next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo)
else:
# TODO: TD(lambda) for multi-step
ratios_ = reshape_time_series(ratios)
if args.value == "gae":
if args.sep_value:
raise NotImplementedError
target_values, advantages = truncated_gae(
adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
target_values, advantages = adv_fn(
next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo)
else:
vtrace_fn = vtrace_2p0s if args.sep_value else vtrace
target_values, advantages = vtrace_fn(
adv_fn = vtrace_sep if args.sep_value else vtrace
target_values, advantages = adv_fn(
next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
......
......@@ -244,7 +244,7 @@ def vtrace(
return targets, advantages
def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry
ratio, cur_values, next_done, r_t, main = inp
......@@ -301,7 +301,7 @@ def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
return carry, (v, q_t, return_t)
def vtrace_2p0s(
def vtrace_sep(
next_value, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
):
......@@ -315,7 +315,7 @@ def vtrace_2p0s(
return1, return2, next_q1, next_q2
_, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_2p0s_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
advantages = q_estimate - values
......@@ -325,7 +325,7 @@ def vtrace_2p0s(
return targets, advantages
def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
cur_value, next_done, reward, main = inp
......@@ -375,7 +375,7 @@ def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
return carry, (advantages, returns)
def truncated_gae_2p0s(
def truncated_gae_sep(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
):
next_value1 = next_value
......@@ -390,12 +390,12 @@ def truncated_gae_2p0s(
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
_, (advantages, returns) = jax.lax.scan(
partial(truncated_gae_upgo_loop, gamma=gamma, gae_lambda=gae_lambda),
partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True
)
targets = values + advantages
if upgo:
advantages += returns - values
targets = values + advantages
targets = jax.lax.stop_gradient(targets)
return targets, advantages
......
......@@ -2,7 +2,7 @@ import jax
import jax.numpy as jnp
def truncated_gae_2p0s(
def truncated_gae_sep(
next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo
):
def body_fn(carry, inp):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment