Commit 974fe861 authored by sbl1996@126.com's avatar sbl1996@126.com

Rename

parent dab0733b
...@@ -28,9 +28,9 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files ...@@ -28,9 +28,9 @@ from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
from ygoai.rl.jax.agent import RNNAgent, ModelArgs from ygoai.rl.jax.agent import RNNAgent, ModelArgs
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax.switch import truncated_gae_sep as gae_sep_switch
from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \ from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss, \
ach_loss, policy_gradient_loss, vtrace, vtrace_2p0s, truncated_gae ach_loss, policy_gradient_loss, vtrace, vtrace_sep, truncated_gae, truncated_gae_sep
from ygoai.rl.jax.switch import truncated_gae_2p0s as gae_2p0s_switch
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
...@@ -745,21 +745,20 @@ def main(): ...@@ -745,21 +745,20 @@ def main():
if args.switch: if args.switch:
if args.value == "vtrace" or args.sep_value: if args.value == "vtrace" or args.sep_value:
raise NotImplementedError raise NotImplementedError
target_values, advantages = gae_2p0s_switch( target_values, advantages = gae_sep_switch(
next_value, new_values_, rewards, next_dones, switch_or_mains, next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo) args.gamma, args.gae_lambda, args.upgo)
else: else:
# TODO: TD(lambda) for multi-step # TODO: TD(lambda) for multi-step
ratios_ = reshape_time_series(ratios) ratios_ = reshape_time_series(ratios)
if args.value == "gae": if args.value == "gae":
if args.sep_value: adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
raise NotImplementedError target_values, advantages = adv_fn(
target_values, advantages = truncated_gae(
next_value, new_values_, rewards, next_dones, switch_or_mains, next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo) args.gamma, args.gae_lambda, args.upgo)
else: else:
vtrace_fn = vtrace_2p0s if args.sep_value else vtrace adv_fn = vtrace_sep if args.sep_value else vtrace
target_values, advantages = vtrace_fn( target_values, advantages = adv_fn(
next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma, next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max) args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
......
...@@ -244,7 +244,7 @@ def vtrace( ...@@ -244,7 +244,7 @@ def vtrace(
return targets, advantages return targets, advantages
def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \ v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry last_return1, last_return2, next_q1, next_q2 = carry
ratio, cur_values, next_done, r_t, main = inp ratio, cur_values, next_done, r_t, main = inp
...@@ -301,7 +301,7 @@ def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): ...@@ -301,7 +301,7 @@ def vtrace_2p0s_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
return carry, (v, q_t, return_t) return carry, (v, q_t, return_t)
def vtrace_2p0s( def vtrace_sep(
next_value, ratios, values, rewards, next_dones, mains, next_value, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False, gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
): ):
...@@ -315,7 +315,7 @@ def vtrace_2p0s( ...@@ -315,7 +315,7 @@ def vtrace_2p0s(
return1, return2, next_q1, next_q2 return1, return2, next_q1, next_q2
_, (targets, q_estimate, return_t) = jax.lax.scan( _, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_2p0s_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max), partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True carry, (ratios, values, next_dones, rewards, mains), reverse=True
) )
advantages = q_estimate - values advantages = q_estimate - values
...@@ -325,7 +325,7 @@ def vtrace_2p0s( ...@@ -325,7 +325,7 @@ def vtrace_2p0s(
return targets, advantages return targets, advantages
def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda): def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \ lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
cur_value, next_done, reward, main = inp cur_value, next_done, reward, main = inp
...@@ -375,7 +375,7 @@ def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda): ...@@ -375,7 +375,7 @@ def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
return carry, (advantages, returns) return carry, (advantages, returns)
def truncated_gae_2p0s( def truncated_gae_sep(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo, next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
): ):
next_value1 = next_value next_value1 = next_value
...@@ -390,12 +390,12 @@ def truncated_gae_2p0s( ...@@ -390,12 +390,12 @@ def truncated_gae_2p0s(
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
_, (advantages, returns) = jax.lax.scan( _, (advantages, returns) = jax.lax.scan(
partial(truncated_gae_upgo_loop, gamma=gamma, gae_lambda=gae_lambda), partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True carry, (values, next_dones, rewards, mains), reverse=True
) )
targets = values + advantages
if upgo: if upgo:
advantages += returns - values advantages += returns - values
targets = values + advantages
targets = jax.lax.stop_gradient(targets) targets = jax.lax.stop_gradient(targets)
return targets, advantages return targets, advantages
......
...@@ -2,7 +2,7 @@ import jax ...@@ -2,7 +2,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
def truncated_gae_2p0s( def truncated_gae_sep(
next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo
): ):
def body_fn(carry, inp): def body_fn(carry, inp):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment