Commit 5cd9807d authored by sbl1996@126.com

Fix gae with upgo

parent 20da4bcc
...@@ -262,7 +262,7 @@ def vtrace( ...@@ -262,7 +262,7 @@ def vtrace(
class VtraceSepCarry(NamedTuple): class VtraceSepCarry(NamedTuple):
v: jnp.ndarray v: jnp.ndarray
next_values: jnp.ndarray next_value: jnp.ndarray
reward: jnp.ndarray reward: jnp.ndarray
xi: jnp.ndarray xi: jnp.ndarray
last_return: jnp.ndarray last_return: jnp.ndarray
...@@ -270,22 +270,18 @@ class VtraceSepCarry(NamedTuple): ...@@ -270,22 +270,18 @@ class VtraceSepCarry(NamedTuple):
def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
(v1, next_values1, reward1, xi1, last_return1, next_q1), \ (v1, next_value1, reward1, xi1, last_return1, next_q1), \
(v2, next_values2, reward2, xi2, last_return2, next_q2) = carry (v2, next_value2, reward2, xi2, last_return2, next_q2) = carry
ratio, cur_values, next_done, r_t, main = inp ratio, cur_value, next_done, r_t, main = inp
v1 = jnp.where(next_done, 0, v1)
v2 = jnp.where(next_done, 0, v2) v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2 = jax.tree.map(
next_values1 = jnp.where(next_done, 0, next_values1) lambda x: jnp.where(next_done, 0, x),
next_values2 = jnp.where(next_done, 0, next_values2) (v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2))
reward1 = jnp.where(next_done, 0, reward1)
reward2 = jnp.where(next_done, 0, reward2)
xi1 = jnp.where(next_done, 1, xi1)
xi2 = jnp.where(next_done, 1, xi2)
discount = gamma * (1.0 - next_done) discount = gamma * (1.0 - next_done)
v = jnp.where(main, v1, v2) v = jnp.where(main, v1, v2)
next_values = jnp.where(main, next_values1, next_values2) next_value = jnp.where(main, next_value1, next_value2)
reward = jnp.where(main, reward1, reward2) reward = jnp.where(main, reward1, reward2)
xi = jnp.where(main, xi1, xi2) xi = jnp.where(main, xi1, xi2)
...@@ -293,20 +289,20 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): ...@@ -293,20 +289,20 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
rho_t = jnp.clip(ratio * xi, rho_min, rho_max) rho_t = jnp.clip(ratio * xi, rho_min, rho_max)
c_t = jnp.clip(ratio * xi, c_min, c_max) c_t = jnp.clip(ratio * xi, c_min, c_max)
sig_v = rho_t * (r_t + ratio * reward + discount * next_values - cur_values) sig_v = rho_t * (r_t + ratio * reward + discount * next_value - cur_value)
v = cur_values + sig_v + c_t * discount * (v - next_values) v = cur_value + sig_v + c_t * discount * (v - next_value)
# UPGO advantage (not corrected by importance sampling, unlike V-trace) # UPGO advantage (not corrected by importance sampling, unlike V-trace)
return_t = jnp.where(main, last_return1, last_return2) return_t = jnp.where(main, last_return1, last_return2)
next_q = jnp.where(main, next_q1, next_q2) next_q = jnp.where(main, next_q1, next_q2)
factor = jnp.where(main, jnp.ones_like(r_t), -jnp.ones_like(r_t)) factor = jnp.where(main, jnp.ones_like(r_t), -jnp.ones_like(r_t))
return_t = r_t + discount * jnp.where( return_t = r_t + discount * jnp.where(
next_q >= next_values, return_t, next_values) next_q >= next_value, return_t, next_value)
last_return1 = jnp.where( last_return1 = jnp.where(
next_done, r_t * factor, jnp.where(main, return_t, last_return1)) next_done, r_t * factor, jnp.where(main, return_t, last_return1))
last_return2 = jnp.where( last_return2 = jnp.where(
next_done, r_t * -factor, jnp.where(main, last_return2, return_t)) next_done, r_t * -factor, jnp.where(main, last_return2, return_t))
next_q = r_t + discount * next_values next_q = r_t + discount * next_value
next_q1 = jnp.where( next_q1 = jnp.where(
next_done, r_t * factor, jnp.where(main, next_q, next_q1)) next_done, r_t * factor, jnp.where(main, next_q, next_q1))
next_q2 = jnp.where( next_q2 = jnp.where(
...@@ -314,15 +310,15 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): ...@@ -314,15 +310,15 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1 = jnp.where(main, v, v1) v1 = jnp.where(main, v, v1)
v2 = jnp.where(main, v2, v) v2 = jnp.where(main, v2, v)
next_values1 = jnp.where(main, cur_values, next_values1) next_value1 = jnp.where(main, cur_value, next_value1)
next_values2 = jnp.where(main, next_values2, cur_values) next_value2 = jnp.where(main, next_value2, cur_value)
reward1 = jnp.where(main, 0, -r_t + ratio * reward1) reward1 = jnp.where(main, 0, -r_t + ratio * reward1)
reward2 = jnp.where(main, -r_t + ratio * reward2, 0) reward2 = jnp.where(main, -r_t + ratio * reward2, 0)
xi1 = jnp.where(main, 1, ratio * xi1) xi1 = jnp.where(main, 1, ratio * xi1)
xi2 = jnp.where(main, ratio * xi2, 1) xi2 = jnp.where(main, ratio * xi2, 1)
carry1 = VtraceSepCarry(v1, next_values1, reward1, xi1, last_return1, next_q1) carry1 = VtraceSepCarry(v1, next_value1, reward1, xi1, last_return1, next_q1)
carry2 = VtraceSepCarry(v2, next_values2, reward2, xi2, last_return2, next_q2) carry2 = VtraceSepCarry(v2, next_value2, reward2, xi2, last_return2, next_q2)
return (carry1, carry2), (v, q_t, return_t) return (carry1, carry2), (v, q_t, return_t)
...@@ -338,7 +334,7 @@ def vtrace_sep( ...@@ -338,7 +334,7 @@ def vtrace_sep(
next_value1 = next_value next_value1 = next_value
carry1 = VtraceSepCarry( carry1 = VtraceSepCarry(
v=next_value1, v=next_value1,
next_values=next_value1, next_value=next_value1,
reward=jnp.zeros_like(next_value), reward=jnp.zeros_like(next_value),
xi=jnp.ones_like(next_value), xi=jnp.ones_like(next_value),
last_return=next_value1, last_return=next_value1,
...@@ -347,7 +343,7 @@ def vtrace_sep( ...@@ -347,7 +343,7 @@ def vtrace_sep(
next_value2 = -next_value1 next_value2 = -next_value1
carry2 = VtraceSepCarry( carry2 = VtraceSepCarry(
v=next_value2, v=next_value2,
next_values=next_value2, next_value=next_value2,
reward=jnp.zeros_like(next_value), reward=jnp.zeros_like(next_value),
xi=jnp.ones_like(next_value), xi=jnp.ones_like(next_value),
last_return=next_value2, last_return=next_value2,
...@@ -397,23 +393,22 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda): ...@@ -397,23 +393,22 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
done_used2 = jnp.where( done_used2 = jnp.where(
next_done, main2, jnp.where(main2 & ~done_used2, True, done_used2)) next_done, main2, jnp.where(main2 & ~done_used2, True, done_used2))
# UPGO advantage
last_return1 = jnp.where(real_done1, 0, last_return1) last_return1 = jnp.where(real_done1, 0, last_return1)
last_return2 = jnp.where(real_done2, 0, last_return2) last_return2 = jnp.where(real_done2, 0, last_return2)
last_return1_ = reward1 + gamma * jnp.where( last_return1_ = reward1 + gamma * jnp.where(
next_q1 >= next_value1, last_return1, next_value1) next_q1 >= next_value1, last_return1, next_value1)
last_return2_ = reward2 + gamma * jnp.where( last_return2_ = reward2 + gamma * jnp.where(
next_q2 >= next_value2, last_return2, next_value2) next_q2 >= next_value2, last_return2, next_value2)
return_t = jnp.where(main1, last_return1_, last_return2_)
last_return1 = jnp.where(main1, last_return1_, last_return1)
last_return2 = jnp.where(main2, last_return2_, last_return2)
next_q1_ = reward1 + gamma * next_value1 next_q1_ = reward1 + gamma * next_value1
next_q2_ = reward2 + gamma * next_value2 next_q2_ = reward2 + gamma * next_value2
next_q1 = jnp.where(main1, next_q1_, next_q1) next_q1 = jnp.where(main1, next_q1_, next_q1)
next_q2 = jnp.where(main2, next_q2_, next_q1) next_q2 = jnp.where(main2, next_q2_, next_q2)
last_return1 = jnp.where(main1, last_return1_, last_return1)
last_return2 = jnp.where(main2, last_return2_, last_return2)
returns = jnp.where(main1, last_return1_, last_return2_)
delta1 = next_q1_ - cur_value delta1 = reward1 + gamma * next_value1 - cur_value
delta2 = next_q2_ - cur_value delta2 = reward2 + gamma * next_value2 - cur_value
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1 lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2 lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(main1, lastgaelam1_, lastgaelam2_) advantages = jnp.where(main1, lastgaelam1_, lastgaelam2_)
...@@ -424,7 +419,7 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda): ...@@ -424,7 +419,7 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1) carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1)
carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2) carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2)
return (carry1, carry2), (advantages, returns) return (carry1, carry2), (advantages, return_t)
def truncated_gae_sep( def truncated_gae_sep(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment