Commit dab0733b authored by sbl1996@126.com's avatar sbl1996@126.com

Fix

parent 831e92ff
......@@ -752,13 +752,13 @@ def main():
# TODO: TD(lambda) for multi-step
ratios_ = reshape_time_series(ratios)
if args.value == "gae":
if not args.sep_value:
if args.sep_value:
raise NotImplementedError
target_values, advantages = truncated_gae(
next_value, new_values_, rewards, next_dones, switch_or_mains,
args.gamma, args.gae_lambda, args.upgo)
else:
vtrace_fn = vtrace if args.sep_value else vtrace_2p0s
vtrace_fn = vtrace_2p0s if args.sep_value else vtrace
target_values, advantages = vtrace_fn(
next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment