Commit 7c8b11c3 authored by sbl1996@126.com's avatar sbl1996@126.com

Add collect_steps

parent e1ff8f92
This diff is collapsed.
from functools import partial
from typing import NamedTuple
import jax
import jax.numpy as jnp
......@@ -193,6 +194,14 @@ def vtrace_rnad(
return targets, q_estimate
class VtraceCarry(NamedTuple):
v: jnp.ndarray
next_value: jnp.ndarray
last_return: jnp.ndarray
next_q: jnp.ndarray
next_main: jnp.ndarray
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v, next_value, last_return, next_q, next_main = carry
ratio, cur_value, next_done, reward, main = inp
......@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
next_q = reward + discount * next_value
carry = v, cur_value, last_return, next_q, main
carry = VtraceCarry(v, next_value, last_return, next_q, main)
return carry, (v, q_t, last_return)
def vtrace(
next_value, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
next_v, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
upgo=False, return_carry=False
):
if isinstance(next_v, (tuple, list)):
carry = next_v
else:
next_value = next_v
v = last_return = next_q = next_value
next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
carry = v, next_value, last_return, next_q, next_main
carry = VtraceCarry(v, next_value, last_return, next_q, next_main)
_, (targets, q_estimate, return_t) = jax.lax.scan(
carry, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
if return_carry:
return carry
advantages = q_estimate - values
if upgo:
advantages += return_t - values
......@@ -244,9 +260,18 @@ def vtrace(
return targets, advantages
class VtraceSepCarry(NamedTuple):
v: jnp.ndarray
next_values: jnp.ndarray
reward: jnp.ndarray
xi: jnp.ndarray
last_return: jnp.ndarray
next_q: jnp.ndarray
def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2 = carry
(v1, next_values1, reward1, xi1, last_return1, next_q1), \
(v2, next_values2, reward2, xi2, last_return2, next_q2) = carry
ratio, cur_values, next_done, r_t, main = inp
v1 = jnp.where(next_done, 0, v1)
......@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
xi1 = jnp.where(main, 1, ratio * xi1)
xi2 = jnp.where(main, ratio * xi2, 1)
carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
last_return1, last_return2, next_q1, next_q2
return carry, (v, q_t, return_t)
carry1 = VtraceSepCarry(v1, next_values1, reward1, xi1, last_return1, next_q1)
carry2 = VtraceSepCarry(v2, next_values2, reward2, xi2, last_return2, next_q2)
return (carry1, carry2), (v, q_t, return_t)
def vtrace_sep(
next_value, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
next_v, ratios, values, rewards, next_dones, mains,
gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
upgo=False, return_carry=False
):
if isinstance(next_v, (tuple, list)):
carry = next_v
else:
next_value = next_v
next_value1 = next_value
carry1 = VtraceSepCarry(
v=next_value1,
next_values=next_value1,
reward=jnp.zeros_like(next_value),
xi=jnp.ones_like(next_value),
last_return=next_value1,
next_q=next_value1,
)
next_value2 = -next_value1
v1 = return1 = next_q1 = next_value1
v2 = return2 = next_q2 = next_value2
reward1 = reward2 = jnp.zeros_like(next_value)
xi1 = xi2 = jnp.ones_like(next_value)
carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
return1, return2, next_q1, next_q2
carry2 = VtraceSepCarry(
v=next_value2,
next_values=next_value2,
reward=jnp.zeros_like(next_value),
xi=jnp.ones_like(next_value),
last_return=next_value2,
next_q=next_value2,
)
carry = carry1, carry2
_, (targets, q_estimate, return_t) = jax.lax.scan(
carry, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
if return_carry:
return carry
advantages = q_estimate - values
if upgo:
advantages += return_t - values
......@@ -325,9 +368,18 @@ def vtrace_sep(
return targets, advantages
class GAESepCarry(NamedTuple):
lastgaelam: jnp.ndarray
next_value: jnp.ndarray
reward: jnp.ndarray
done_used: jnp.ndarray
last_return: jnp.ndarray
next_q: jnp.ndarray
def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1), \
(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2) = carry
cur_value, next_done, reward, main = inp
main1 = main
main2 = ~main
......@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)
carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
return carry, (advantages, returns)
carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1)
carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2)
return (carry1, carry2), (advantages, returns)
def truncated_gae_sep(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
):
next_value1 = next_value
next_value2 = -next_value1
last_return1 = next_q1 = next_value1
last_return2 = next_q2 = next_value2
done_used1 = jnp.ones_like(next_dones[-1])
done_used2 = jnp.ones_like(next_dones[-1])
reward1 = reward2 = jnp.zeros_like(next_value)
lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
_, (advantages, returns) = jax.lax.scan(
next_v, values, rewards, next_dones, mains, gamma, gae_lambda, upgo, return_carry=False):
if isinstance(next_v, (tuple, list)):
carry = next_v
else:
next_value = next_v
carry1 = GAESepCarry(
lastgaelam=jnp.zeros_like(next_value),
next_value=next_value,
reward=jnp.zeros_like(next_value),
done_used=jnp.ones_like(next_dones[-1]),
last_return=next_value,
next_q=next_value,
)
carry2 = GAESepCarry(
lastgaelam=jnp.zeros_like(next_value),
next_value=-next_value,
reward=jnp.zeros_like(next_value),
done_used=jnp.ones_like(next_dones[-1]),
last_return=-next_value,
next_q=-next_value,
)
carry = carry1, carry2
carry, (advantages, returns) = jax.lax.scan(
partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True
)
if return_carry:
return carry
targets = values + advantages
if upgo:
advantages += returns - values
......@@ -400,6 +463,14 @@ def truncated_gae_sep(
return targets, advantages
class GAECarry(NamedTuple):
lastgaelam: jnp.ndarray
next_value: jnp.ndarray
last_return: jnp.ndarray
next_q: jnp.ndarray
next_main: jnp.ndarray
def truncated_gae_loop(carry, inp, gamma, gae_lambda):
lastgaelam, next_value, last_return, next_q, next_main = carry
cur_value, next_done, reward, main = inp
......@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda):
next_q = reward + discount * next_value
carry = lastgaelam, cur_value, last_return, next_q, main
carry = GAECarry(lastgaelam, cur_value, last_return, next_q, main)
return carry, (lastgaelam, last_return)
def truncated_gae(
next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False):
lastgaelam = jnp.zeros_like(next_value)
last_return = next_q = next_value
next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
carry = lastgaelam, next_value, last_return, next_q, next_main
_, (advantages, return_t) = jax.lax.scan(
next_v, values, rewards, next_dones, mains, gamma, gae_lambda,
upgo=False, return_carry=False):
if isinstance(next_v, (tuple, list)):
carry = next_v
else:
next_value = next_v
carry = GAECarry(
lastgaelam=jnp.zeros_like(next_value),
next_value=next_value,
last_return=next_value,
next_q=next_value,
next_main=jnp.ones_like(next_value, dtype=jnp.bool_),
)
carry, (advantages, return_t) = jax.lax.scan(
partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True
)
if return_carry:
return carry
targets = values + advantages
if upgo:
advantages += return_t - values
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment