Commit 7c8b11c3 authored by sbl1996@126.com

Add collect_steps

parent e1ff8f92
This diff is collapsed.
from functools import partial from functools import partial
from typing import NamedTuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -193,6 +194,14 @@ def vtrace_rnad( ...@@ -193,6 +194,14 @@ def vtrace_rnad(
return targets, q_estimate return targets, q_estimate
class VtraceCarry(NamedTuple):
    """Scan carry threaded through `vtrace_loop` by `vtrace`.

    `vtrace` seeds it from the bootstrap value (`v`, `next_value`,
    `last_return` and `next_q` all start equal to it; `next_main` starts as
    an all-True boolean array), runs `jax.lax.scan(..., reverse=True)` with
    it, and returns the final carry instead of `(targets, advantages)` when
    `return_carry=True`.  Any tuple/list passed to `vtrace` as `next_v` is
    used directly as the initial carry.

    `vtrace_loop` unpacks the carry positionally, so the field order below
    must match that unpacking.
    """

    # Running value; emitted as the first per-step scan output, which
    # `vtrace` collects as `targets`.
    v: jnp.ndarray
    # Value slot read by the next scan step; `vtrace_loop` uses it in
    # `reward + discount * next_value`.
    # NOTE(review): the tuple this class replaced carried `cur_value` in this
    # slot, while the NamedTuple construction in `vtrace_loop` passes
    # `next_value`; `truncated_gae_loop` still carries `cur_value`.  Confirm
    # the change is intended.
    next_value: jnp.ndarray
    # Emitted as the third per-step scan output; `vtrace` adds
    # `return_t - values` to the advantages when `upgo` is set.
    last_return: jnp.ndarray
    # Set to `reward + discount * next_value` at the end of each step.
    next_q: jnp.ndarray
    # The `main` flag of the step that produced this carry.
    next_main: jnp.ndarray
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v, next_value, last_return, next_q, next_main = carry v, next_value, last_return, next_q, next_main = carry
ratio, cur_value, next_done, reward, main = inp ratio, cur_value, next_done, reward, main = inp
...@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): ...@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
next_q = reward + discount * next_value next_q = reward + discount * next_value
carry = v, cur_value, last_return, next_q, main carry = VtraceCarry(v, next_value, last_return, next_q, main)
return carry, (v, q_t, last_return) return carry, (v, q_t, last_return)
def vtrace( def vtrace(
next_value, ratios, values, rewards, next_dones, mains, next_v, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False, gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
upgo=False, return_carry=False
): ):
if isinstance(next_v, (tuple, list)):
carry = next_v
else:
next_value = next_v
v = last_return = next_q = next_value v = last_return = next_q = next_value
next_main = jnp.ones_like(next_value, dtype=jnp.bool_) next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
carry = v, next_value, last_return, next_q, next_main carry = VtraceCarry(v, next_value, last_return, next_q, next_main)
_, (targets, q_estimate, return_t) = jax.lax.scan( carry, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max), partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True carry, (ratios, values, next_dones, rewards, mains), reverse=True
) )
if return_carry:
return carry
advantages = q_estimate - values advantages = q_estimate - values
if upgo: if upgo:
advantages += return_t - values advantages += return_t - values
...@@ -244,9 +260,18 @@ def vtrace( ...@@ -244,9 +260,18 @@ def vtrace(
return targets, advantages return targets, advantages
class VtraceSepCarry(NamedTuple):
    """One side's half of the scan carry used by `vtrace_sep`.

    The full carry is a pair `(carry1, carry2)` of these.  `vtrace_sep` seeds
    `carry1` from the bootstrap value and `carry2` from its negation, and
    returns the final pair when `return_carry=True`.  Any tuple/list passed
    to `vtrace_sep` as `next_v` is used directly as the initial carry.

    `vtrace_sep_loop` unpacks each half positionally, so the field order
    below must match that unpacking.
    """

    # Seeded with this side's bootstrap value.
    v: jnp.ndarray
    # Seeded with this side's bootstrap value.
    next_values: jnp.ndarray
    # Seeded with zeros.
    reward: jnp.ndarray
    # Seeded with ones; each step either resets it to 1 or multiplies it by
    # the step's `ratio`, selected by `main` (opposite choice on each side).
    xi: jnp.ndarray
    # Seeded with this side's bootstrap value.
    last_return: jnp.ndarray
    # Seeded with this side's bootstrap value.
    next_q: jnp.ndarray
def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \ (v1, next_values1, reward1, xi1, last_return1, next_q1), \
last_return1, last_return2, next_q1, next_q2 = carry (v2, next_values2, reward2, xi2, last_return2, next_q2) = carry
ratio, cur_values, next_done, r_t, main = inp ratio, cur_values, next_done, r_t, main = inp
v1 = jnp.where(next_done, 0, v1) v1 = jnp.where(next_done, 0, v1)
...@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max): ...@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
xi1 = jnp.where(main, 1, ratio * xi1) xi1 = jnp.where(main, 1, ratio * xi1)
xi2 = jnp.where(main, ratio * xi2, 1) xi2 = jnp.where(main, ratio * xi2, 1)
carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \ carry1 = VtraceSepCarry(v1, next_values1, reward1, xi1, last_return1, next_q1)
last_return1, last_return2, next_q1, next_q2 carry2 = VtraceSepCarry(v2, next_values2, reward2, xi2, last_return2, next_q2)
return carry, (v, q_t, return_t) return (carry1, carry2), (v, q_t, return_t)
def vtrace_sep( def vtrace_sep(
next_value, ratios, values, rewards, next_dones, mains, next_v, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False, gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
upgo=False, return_carry=False
): ):
if isinstance(next_v, (tuple, list)):
carry = next_v
else:
next_value = next_v
next_value1 = next_value next_value1 = next_value
carry1 = VtraceSepCarry(
v=next_value1,
next_values=next_value1,
reward=jnp.zeros_like(next_value),
xi=jnp.ones_like(next_value),
last_return=next_value1,
next_q=next_value1,
)
next_value2 = -next_value1 next_value2 = -next_value1
v1 = return1 = next_q1 = next_value1 carry2 = VtraceSepCarry(
v2 = return2 = next_q2 = next_value2 v=next_value2,
reward1 = reward2 = jnp.zeros_like(next_value) next_values=next_value2,
xi1 = xi2 = jnp.ones_like(next_value) reward=jnp.zeros_like(next_value),
carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \ xi=jnp.ones_like(next_value),
return1, return2, next_q1, next_q2 last_return=next_value2,
next_q=next_value2,
)
carry = carry1, carry2
_, (targets, q_estimate, return_t) = jax.lax.scan( carry, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max), partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True carry, (ratios, values, next_dones, rewards, mains), reverse=True
) )
if return_carry:
return carry
advantages = q_estimate - values advantages = q_estimate - values
if upgo: if upgo:
advantages += return_t - values advantages += return_t - values
...@@ -325,9 +368,18 @@ def vtrace_sep( ...@@ -325,9 +368,18 @@ def vtrace_sep(
return targets, advantages return targets, advantages
class GAESepCarry(NamedTuple):
    """One side's half of the scan carry used by `truncated_gae_sep`.

    The full carry is a pair `(carry1, carry2)` of these.
    `truncated_gae_sep` seeds `carry1` from the bootstrap value and `carry2`
    from its negation, and returns the final pair when `return_carry=True`.
    Any tuple/list passed as `next_v` is used directly as the initial carry.

    `truncated_gae_sep_loop` unpacks each half positionally, so the field
    order below must match that unpacking.
    """

    # Seeded with zeros; only updated on steps where `main` selects this
    # side (`main` for the first half, `~main` for the second).
    lastgaelam: jnp.ndarray
    # Seeded with this side's bootstrap value.
    next_value: jnp.ndarray
    # Seeded with zeros.
    reward: jnp.ndarray
    # Seeded with `jnp.ones_like(next_dones[-1])`.
    # NOTE(review): how the loop consumes this flag is not visible in the
    # collapsed part of the diff; confirm before relying on it.
    done_used: jnp.ndarray
    # Seeded with this side's bootstrap value.
    last_return: jnp.ndarray
    # Seeded with this side's bootstrap value.
    next_q: jnp.ndarray
def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda): def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \ (lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1), \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry (lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2) = carry
cur_value, next_done, reward, main = inp cur_value, next_done, reward, main = inp
main1 = main main1 = main
main2 = ~main main2 = ~main
...@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda): ...@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1) lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2) lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)
carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \ carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1)
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2)
return carry, (advantages, returns) return (carry1, carry2), (advantages, returns)
def truncated_gae_sep( def truncated_gae_sep(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo, next_v, values, rewards, next_dones, mains, gamma, gae_lambda, upgo, return_carry=False):
): if isinstance(next_v, (tuple, list)):
next_value1 = next_value carry = next_v
next_value2 = -next_value1 else:
last_return1 = next_q1 = next_value1 next_value = next_v
last_return2 = next_q2 = next_value2 carry1 = GAESepCarry(
done_used1 = jnp.ones_like(next_dones[-1]) lastgaelam=jnp.zeros_like(next_value),
done_used2 = jnp.ones_like(next_dones[-1]) next_value=next_value,
reward1 = reward2 = jnp.zeros_like(next_value) reward=jnp.zeros_like(next_value),
lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value) done_used=jnp.ones_like(next_dones[-1]),
carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \ last_return=next_value,
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 next_q=next_value,
)
_, (advantages, returns) = jax.lax.scan( carry2 = GAESepCarry(
lastgaelam=jnp.zeros_like(next_value),
next_value=-next_value,
reward=jnp.zeros_like(next_value),
done_used=jnp.ones_like(next_dones[-1]),
last_return=-next_value,
next_q=-next_value,
)
carry = carry1, carry2
carry, (advantages, returns) = jax.lax.scan(
partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda), partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True carry, (values, next_dones, rewards, mains), reverse=True
) )
if return_carry:
return carry
targets = values + advantages targets = values + advantages
if upgo: if upgo:
advantages += returns - values advantages += returns - values
...@@ -400,6 +463,14 @@ def truncated_gae_sep( ...@@ -400,6 +463,14 @@ def truncated_gae_sep(
return targets, advantages return targets, advantages
class GAECarry(NamedTuple):
    """Scan carry threaded through `truncated_gae_loop` by `truncated_gae`.

    `truncated_gae` seeds it from the bootstrap value, runs
    `jax.lax.scan(..., reverse=True)` with it, and returns the final carry
    instead of its usual result when `return_carry=True`.  Any tuple/list
    passed as `next_v` is used directly as the initial carry.

    `truncated_gae_loop` unpacks the carry positionally, so the field order
    below must match that unpacking.
    """

    # Seeded with zeros; emitted as the first per-step scan output, which
    # `truncated_gae` collects as `advantages`.
    lastgaelam: jnp.ndarray
    # Seeded with the bootstrap value; each step stores its `cur_value` here
    # for the next scan step to read.
    next_value: jnp.ndarray
    # Seeded with the bootstrap value; emitted as the second per-step scan
    # output, which `truncated_gae` uses as `return_t` when `upgo` is set.
    last_return: jnp.ndarray
    # Seeded with the bootstrap value; set to
    # `reward + discount * next_value` at the end of each step.
    next_q: jnp.ndarray
    # Seeded as an all-True boolean array; each step stores its `main` flag
    # here for the next scan step to read.
    next_main: jnp.ndarray
def truncated_gae_loop(carry, inp, gamma, gae_lambda): def truncated_gae_loop(carry, inp, gamma, gae_lambda):
lastgaelam, next_value, last_return, next_q, next_main = carry lastgaelam, next_value, last_return, next_q, next_main = carry
cur_value, next_done, reward, main = inp cur_value, next_done, reward, main = inp
...@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda): ...@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda):
next_q = reward + discount * next_value next_q = reward + discount * next_value
carry = lastgaelam, cur_value, last_return, next_q, main carry = GAECarry(lastgaelam, cur_value, last_return, next_q, main)
return carry, (lastgaelam, last_return) return carry, (lastgaelam, last_return)
def truncated_gae( def truncated_gae(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False): next_v, values, rewards, next_dones, mains, gamma, gae_lambda,
lastgaelam = jnp.zeros_like(next_value) upgo=False, return_carry=False):
last_return = next_q = next_value if isinstance(next_v, (tuple, list)):
next_main = jnp.ones_like(next_value, dtype=jnp.bool_) carry = next_v
carry = lastgaelam, next_value, last_return, next_q, next_main else:
_, (advantages, return_t) = jax.lax.scan( next_value = next_v
carry = GAECarry(
lastgaelam=jnp.zeros_like(next_value),
next_value=next_value,
last_return=next_value,
next_q=next_value,
next_main=jnp.ones_like(next_value, dtype=jnp.bool_),
)
carry, (advantages, return_t) = jax.lax.scan(
partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda), partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True carry, (values, next_dones, rewards, mains), reverse=True
) )
if return_carry:
return carry
targets = values + advantages targets = values + advantages
if upgo: if upgo:
advantages += return_t - values advantages += return_t - values
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment