Commit 7c8b11c3 authored by sbl1996@126.com

Add collect_steps

parent e1ff8f92
@@ -97,6 +97,8 @@ class Args:
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
+collect_steps: Optional[int] = None
+"""the number of steps to collect before computing the advantages (defaults to num_steps)"""
segment_length: Optional[int] = None
"""the length of the segment for training"""
anneal_lr: bool = False
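A minimal configuration sketch (values are hypothetical): `collect_steps` extends the rollout beyond the training window so advantages can be bootstrapped through real transitions rather than a single value estimate.

```python
# Hypothetical values; collect_steps=None keeps the old behavior, since
# main() defaults it to num_steps.
args = Args(num_steps=128, collect_steps=160)
lookahead = args.collect_steps - args.num_steps  # 32 extra steps per rollout
```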
@@ -226,6 +228,7 @@ class Transition(NamedTuple):
dones: list
actions: list
logits: list
+values: list
rewards: list
mains: list
next_dones: list
@@ -304,6 +307,31 @@ def reshape_minibatch(
return x
+def advantage_fn(
+args, next_v, values, rewards, next_dones, switch_or_mains, ratios=None, return_carry=False):
+if args.switch:
+if args.value == "vtrace" or args.sep_value or return_carry:
+raise NotImplementedError
+return gae_sep_switch(
+next_v, values, rewards, next_dones, switch_or_mains,
+args.gamma, args.gae_lambda, args.upgo)
+else:
+# TODO: TD(lambda) for multi-step
+if args.value == "gae":
+adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
+return adv_fn(
+next_v, values, rewards, next_dones, switch_or_mains,
+args.gamma, args.gae_lambda, args.upgo, return_carry=return_carry)
+else:
+adv_fn = vtrace_sep if args.sep_value else vtrace
+if ratios is None:
+ratios = jnp.ones_like(values)
+return adv_fn(
+next_v, ratios, values, rewards, next_dones, switch_or_mains, args.gamma,
+args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max,
+return_carry=return_carry)
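A usage sketch of the dispatcher above, assuming the estimators from this commit are importable and time-major inputs of shape (num_steps, num_envs); the hyperparameters here are illustrative only:

```python
from types import SimpleNamespace
import jax.numpy as jnp

T, B = 4, 2
args = SimpleNamespace(switch=False, value="gae", sep_value=False,
                       upgo=False, gamma=0.99, gae_lambda=0.95)
targets, advantages = advantage_fn(
    args,
    jnp.zeros(B),                        # next_v: bootstrap value
    jnp.zeros((T, B)),                   # values
    jnp.ones((T, B)),                    # rewards
    jnp.zeros((T, B), dtype=jnp.bool_),  # next_dones
    jnp.ones((T, B), dtype=jnp.bool_),   # mains
)
```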
def rollout(
key: jax.random.PRNGKey,
args: Args,
@@ -370,10 +398,17 @@ def rollout(
@jax.jit
def sample_action(
params, next_obs, rstate1, rstate2, main, done, key):
-(rstate1, rstate2), logits = agent.apply(
-params, next_obs, (rstate1, rstate2), done, main)[:2]
+(rstate1, rstate2), logits, value = agent.apply(
+params, next_obs, (rstate1, rstate2), done, main)[:3]
+value = jnp.squeeze(value, axis=-1)
action, key = categorical_sample(logits, key)
-return next_obs, done, main, rstate1, rstate2, action, logits, key
+return next_obs, done, main, rstate1, rstate2, action, logits, value, key
+@jax.jit
+def compute_advantage_carry(
+next_value, values, rewards, next_dones, mains):
+return advantage_fn(
+args, next_value, values, rewards, next_dones, mains, return_carry=True)
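Schematically, `return_carry=True` runs the same backward scan but returns only the scan state at the window boundary; the learner can later pass that carry back in as `next_v` (the estimators detect it via `isinstance(next_v, (tuple, list))`). A sketch with hypothetical arrays:

```python
# Actor: scan the lookahead steps to get the boundary carry.
carry = compute_advantage_carry(
    next_value, values_ahead, rewards_ahead, next_dones_ahead, mains_ahead)
# Learner: resume the scan over the training window from that carry.
targets, advantages = advantage_fn(
    args, carry, values_train, rewards_train, next_dones_train, mains_train)
```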
deck_names = args.deck_names
deck_avg_times = {name: 0 for name in deck_names}
@@ -400,11 +435,17 @@ def rollout(
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
+start_step = 0
+storage = []
+init_rstates = []
+# @jax.jit
+# def prepare_data(storage: List[Transition]) -> Transition:
+#     return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
-return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
+return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
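A minimal demo (with a stand-in for `Transition`) of what the reworked `prepare_data` returns: one time-major pytree stacked from the list of per-step transitions; the per-device split now happens later, only for the trained window.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Step(NamedTuple):  # stand-in for Transition
    obs: jnp.ndarray
    rewards: jnp.ndarray

storage = [Step(obs=jnp.ones(3), rewards=jnp.zeros(3)) for _ in range(5)]
batch = jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
print(batch.obs.shape)  # (5, 3): (num_steps, num_envs)
```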
for update in range(1, args.num_updates + 2):
if update == 10:
@@ -426,16 +467,18 @@ def rollout(
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
-init_rstate1, init_rstate2 = jax.tree.map(
-lambda x: x.copy(), (next_rstate1, next_rstate2))
-for k in range(args.num_steps):
+for k in range(start_step, args.collect_steps):
+if k % args.num_steps == 0:
+init_rstate1, init_rstate2 = jax.tree.map(
+lambda x: x.copy(), (next_rstate1, next_rstate2))
+init_rstates.append((init_rstate1, init_rstate2))
global_step += args.local_num_envs * n_actors * args.world_size
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, cached_next_done, cached_main, \
-next_rstate1, next_rstate2, action, logits, key = sample_action(
+next_rstate1, next_rstate2, action, logits, value, key = sample_action(
params, next_obs, next_rstate1, next_rstate2, main, next_done, key)
cpu_action = np.array(action)
@@ -453,6 +496,7 @@ def rollout(
mains=cached_main,
actions=action,
logits=logits,
+values=value,
rewards=next_reward,
next_dones=next_done,
)
@@ -495,8 +539,29 @@ def rollout(
rollout_time.append(time.time() - rollout_time_start)
-partitioned_storage = prepare_data(storage)
-storage = []
+start_step = args.collect_steps - args.num_steps
+next_main = main_player == next_to_play
+if args.collect_steps == args.num_steps:
+storage_t = storage
+storage = []
+next_data = (next_obs, next_main)
+else:
+storage_t = storage[:args.num_steps]
+storage = storage[args.num_steps:]
+values, rewards, next_dones, mains = prepare_data([
+(t.values, t.rewards, t.next_dones, t.mains) for t in storage])
+next_value = sample_action(
+params, next_obs, next_rstate1, next_rstate2, next_main, next_done, key)[-2]
+next_value = jnp.where(next_main, next_value, -next_value)
+adv_carry = compute_advantage_carry(
+next_value, values, rewards, next_dones, mains)
+next_data = adv_carry
+partitioned_storage = jax.tree.map(
+lambda x: jnp.split(x, len(learner_devices), axis=1), prepare_data(storage_t))
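A toy trace of the sliding window above, with hypothetical sizes num_steps=128 and collect_steps=160: every update trains on the first 128 buffered steps and keeps the remaining 32 as lookahead for the advantage carry.

```python
num_steps, collect_steps = 128, 160
start_step, buffered = 0, 0
for update in range(1, 4):
    buffered += collect_steps - start_step      # newly collected steps
    trained, buffered = num_steps, buffered - num_steps
    start_step = collect_steps - num_steps      # resume point, as above
    print(update, trained, buffered)            # -> 128 trained, 32 kept
```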
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
@@ -508,10 +573,11 @@ def rollout(
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
-next_main = main_player == next_to_play
+init_rstate = init_rstates.pop(0)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
-(init_rstate1, init_rstate2, next_obs, next_main))
+(init_rstate, next_data))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
@@ -594,6 +660,8 @@ def main():
args.local_env_threads = args.local_env_threads or args.local_num_envs
if args.segment_length is not None:
assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"
+args.collect_steps = args.collect_steps or args.num_steps
+assert args.collect_steps >= args.num_steps, "collect_steps must be greater than or equal to num_steps"
if args.embedding_file:
embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -734,10 +802,10 @@ def main():
else:
eval_variables = None
-def advantage_fn(
+def compute_advantage(
new_logits, new_values, next_dones, switch_or_mains,
-actions, logits, rewards, next_value):
-num_envs = jax.tree.leaves(next_value)[0].shape[0]
+actions, logits, rewards, next_v):
+num_envs = jax.tree.leaves(next_v)[0].shape[0]
num_steps = next_dones.shape[0] // num_envs
def reshape_time_series(x):
@@ -745,37 +813,20 @@ def main():
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
+ratios = reshape_time_series(ratios)
new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
)
-# Advantages and target values
-if args.switch:
-if args.value == "vtrace" or args.sep_value:
-raise NotImplementedError
-target_values, advantages = gae_sep_switch(
-next_value, new_values_, rewards, next_dones, switch_or_mains,
-args.gamma, args.gae_lambda, args.upgo)
-else:
-# TODO: TD(lambda) for multi-step
-ratios_ = reshape_time_series(ratios)
-if args.value == "gae":
-adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
-target_values, advantages = adv_fn(
-next_value, new_values_, rewards, next_dones, switch_or_mains,
-args.gamma, args.gae_lambda, args.upgo)
-else:
-adv_fn = vtrace_sep if args.sep_value else vtrace
-target_values, advantages = adv_fn(
-next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
-args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
+target_values, advantages = advantage_fn(
+args, next_v, new_values_, rewards, next_dones, switch_or_mains, ratios)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
return target_values, advantages
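The shape bookkeeping implied above, with hypothetical T and B: the losses work on flattened `(num_steps * num_envs, ...)` arrays, while the advantage scan is time-major, so the data is reshaped in and flattened back out.

```python
import jax.numpy as jnp

T, B = 4, 2
flat = jnp.arange(T * B, dtype=jnp.float32)  # as produced by apply_fn
series = flat.reshape(T, B)                  # reshape_time_series: scan over axis 0
assert (series.reshape(-1) == flat).all()    # flattened again for the loss
```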
-def loss_fn(
+def compute_loss(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=None):
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
@@ -820,17 +871,24 @@ def main():
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, pg_loss, v_loss, ent_loss, approx_kl
-def apply_fn(variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
+def apply_fn(
+variables, obs, init_rstate, dones, next_dones, switch_or_mains, train=True):
if args.switch:
dones = dones | next_dones
-((rstate1, rstate2), new_logits, new_values, _), state_updates = agent.apply(
-variables, obs, (rstate1, rstate2), dones, switch_or_mains,
-train=True, mutable=["batch_stats"])
+mutable = ["batch_stats"] if train else False
+rets = agent.apply(
+variables, obs, init_rstate, dones, switch_or_mains,
+train=train, mutable=mutable)
+if train:
+((rstate1, rstate2), new_logits, new_values, _), state_updates = rets
+else:
+(rstate1, rstate2), new_logits, new_values, _ = rets
+state_updates = {}
new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
return ((rstate1, rstate2), new_logits, new_values), state_updates
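The `train`/`mutable` pattern above follows the standard Flax recipe for BatchNorm-style state; a self-contained toy model (not the repo's agent):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(4)(x)
        return nn.BatchNorm(use_running_average=not train)(x)

net = Net()
x = jnp.ones((2, 3))
variables = net.init(jax.random.PRNGKey(0), x, train=False)
# train=True: running statistics are mutable and returned separately
y, state_updates = net.apply(variables, x, train=True, mutable=["batch_stats"])
# train=False: nothing is mutable, a single return value
y = net.apply(variables, x, train=False)
```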
-def compute_next_value(
-variables, rstate1, rstate2, next_obs, next_main):
+def compute_next_value(variables, next_rstate, next_obs, next_main):
+rstate1, rstate2 = next_rstate
rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), rstate1, rstate2)
next_value = agent.apply(variables, next_obs, rstate)[2]
@@ -840,39 +898,39 @@ def main():
next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
return next_value
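Two ideas are at work in `compute_next_value`, illustrated below with hypothetical shapes: each env's recurrent state is picked by which player moves next, and the value is flipped into the main player's frame under the zero-sum assumption.

```python
import jax.numpy as jnp

next_main = jnp.array([True, False])               # (num_envs,)
rstate1, rstate2 = jnp.ones((2, 4)), -jnp.ones((2, 4))
rstate = jnp.where(next_main[:, None], rstate1, rstate2)

value = jnp.array([0.3, 0.5])                      # value of the acting player
next_value = jnp.where(next_main, value, -value)   # main player's perspective
```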
-def compute_advantage(
-variables, rstate1, rstate2, obs, dones, next_dones,
+def get_advantage(
+variables, init_rstate, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_obs, next_main):
-segment_length = dones.shape[0]
+num_steps = dones.shape[0]
obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
jax.tree.map(
lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
(obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
-((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+(next_rstate, new_logits, new_values), state_updates = apply_fn(
+variables, obs, init_rstate, dones, next_dones, switch_or_mains, train=False)
next_value = compute_next_value(
-variables, rstate1, rstate2, next_obs, next_main)
+variables, next_rstate, next_obs, next_main)
-target_values, advantages = advantage_fn(
+target_values, advantages = compute_advantage(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value)
target_values, advantages = jax.tree.map(
-lambda x: jnp.reshape(x, (segment_length, -1) + x.shape[2:]),
+lambda x: jnp.reshape(x, (num_steps, -1) + x.shape[2:]),
(target_values, advantages))
return target_values, advantages
-def compute_loss(
-params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
+def get_loss(
+params, batch_stats, init_rstate, obs, dones, next_dones,
switch_or_mains, actions, logits, target_values, advantages, mask):
variables = {'params': params, 'batch_stats': batch_stats}
((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+variables, obs, init_rstate, dones, next_dones, switch_or_mains)
-loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=None)
@@ -881,23 +939,27 @@ def main():
jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
-def compute_advantage_loss(
-params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
-switch_or_mains, actions, logits, rewards, mask, next_obs, next_main):
-num_envs = jax.tree.leaves(next_main)[0].shape[0]
+def get_advantage_loss(
+params, batch_stats, init_rstate, obs, dones, next_dones,
+switch_or_mains, actions, logits, rewards, mask, next_data):
+num_envs = jax.tree.leaves(next_data)[0].shape[0]
variables = {'params': params, 'batch_stats': batch_stats}
-((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
-variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
-next_value = compute_next_value(
-variables, rstate1, rstate2, next_obs, next_main)
+(next_rstate, new_logits, new_values), state_updates = apply_fn(
+variables, obs, init_rstate, dones, next_dones, switch_or_mains)
+if args.collect_steps == args.num_steps:
+next_obs, next_main = next_data
+variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
+next_v = compute_next_value(
+variables, next_rstate, next_obs, next_main)
+else:
+next_v = next_data
-target_values, advantages = advantage_fn(
+target_values, advantages = compute_advantage(
new_logits, new_values, next_dones, switch_or_mains,
-actions, logits, rewards, next_value)
+actions, logits, rewards, next_v)
-loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
new_logits, new_values, actions, logits, target_values, advantages,
mask, num_steps=dones.shape[0] // num_envs)
@@ -908,18 +970,15 @@ def main():
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
-sharded_init_rstate1: List,
-sharded_init_rstate2: List,
-sharded_next_obs: List,
-sharded_next_main: List,
+sharded_init_rstate: List,
+sharded_next_data: List,
key: jax.random.PRNGKey,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-next_obs, init_rstate1, init_rstate2 = [
+next_data, init_rstate = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-for x in [sharded_next_obs, sharded_init_rstate1, sharded_init_rstate2]
+for x in [sharded_next_data, sharded_init_rstate]
]
-next_main = jnp.concatenate(sharded_next_main)
# reorder storage of individual players
# main first, opponent second
@@ -934,9 +993,10 @@ def main():
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
if args.segment_length is None:
-loss_grad_fn = jax.value_and_grad(compute_advantage_loss, has_aux=True)
+loss_grad_fn = jax.value_and_grad(get_advantage_loss, has_aux=True)
else:
-loss_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
+# TODO: fix it
+loss_grad_fn = jax.value_and_grad(get_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
@@ -947,9 +1007,9 @@ def main():
return reshape_minibatch(
x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)
-b_init_rstate1, b_init_rstate2, b_next_obs, b_next_main = \
+b_init_rstate, b_next_data = \
jax.tree.map(partial(convert_data, multi_step=False),
-(init_rstate1, init_rstate2, next_obs, next_main))
+(init_rstate, next_data))
b_storage = jax.tree.map(convert_data, storage)
if args.switch:
switch_or_mains = convert_data(switch)
@@ -969,31 +1029,30 @@ def main():
else:
def update_minibatch(carry, minibatch):
def update_minibatch_t(carry, minibatch_t):
-agent_state, rstate1, rstate2 = carry
-minibatch_t = rstate1, rstate2, *minibatch_t
-(loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
+agent_state, init_rstate = carry
+minibatch_t = init_rstate, *minibatch_t
+(loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, next_rstate)), \
grads = loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch_t)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
-return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
+return (agent_state, next_rstate), (loss, pg_loss, v_loss, ent_loss, approx_kl)
-rstate1, rstate2, *minibatch_t, mask = minibatch
-target_values, advantages = compute_advantage(
-get_variables(carry), rstate1, rstate2, *minibatch_t)
+init_rstate, *minibatch_t, mask = minibatch
+target_values, advantages = get_advantage(
+get_variables(carry), init_rstate, *minibatch_t)
minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
-(carry, _rstate1, _rstate2), \
+(carry, _next_rstate), \
(loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
-update_minibatch_t, (carry, rstate1, rstate2), minibatch_t)
+update_minibatch_t, (carry, init_rstate), minibatch_t)
return carry, (loss, pg_loss, v_loss, ent_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
-b_init_rstate1,
-b_init_rstate2,
+b_init_rstate,
b_storage.obs,
b_storage.dones,
b_storage.next_dones,
@@ -1002,8 +1061,7 @@ def main():
b_storage.logits,
b_rewards,
b_mask,
-b_next_obs,
-b_next_main,
+b_next_data,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
@@ -1023,7 +1081,7 @@ def main():
axis_name="main_devices",
devices=global_main_devices,
)
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
......
from functools import partial
+from typing import NamedTuple
import jax
import jax.numpy as jnp
@@ -193,6 +194,14 @@ def vtrace_rnad(
return targets, q_estimate
+class VtraceCarry(NamedTuple):
+v: jnp.ndarray
+next_value: jnp.ndarray
+last_return: jnp.ndarray
+next_q: jnp.ndarray
+next_main: jnp.ndarray
def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
v, next_value, last_return, next_q, next_main = carry
ratio, cur_value, next_done, reward, main = inp
@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
next_q = reward + discount * next_value
-carry = v, cur_value, last_return, next_q, main
+carry = VtraceCarry(v, next_value, last_return, next_q, main)
return carry, (v, q_t, last_return)
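The switch from a bare tuple to `VtraceCarry` is behavior-preserving: `jax.lax.scan` accepts any pytree as carry, and a NamedTuple simply names the recurrence's fields. A toy demo of the pattern (not the vtrace recurrence itself):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Carry(NamedTuple):
    total: jnp.ndarray

def step(carry: Carry, x):
    carry = Carry(total=carry.total + x)
    return carry, carry.total

carry, cumsum = jax.lax.scan(step, Carry(jnp.float32(0)), jnp.arange(4.0))
print(carry.total, cumsum)  # 6.0 [0. 1. 3. 6.]
```

Because a NamedTuple is still a tuple, the `isinstance(next_v, (tuple, list))` checks below can distinguish a resumed carry from a plain bootstrap value array.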
def vtrace(
-next_value, ratios, values, rewards, next_dones, mains,
-gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+next_v, ratios, values, rewards, next_dones, mains,
+gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
+upgo=False, return_carry=False
):
-v = last_return = next_q = next_value
-next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
-carry = v, next_value, last_return, next_q, next_main
-_, (targets, q_estimate, return_t) = jax.lax.scan(
+if isinstance(next_v, (tuple, list)):
+carry = next_v
+else:
+next_value = next_v
+v = last_return = next_q = next_value
+next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+carry = VtraceCarry(v, next_value, last_return, next_q, next_main)
+carry, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
+if return_carry:
+return carry
advantages = q_estimate - values
if upgo:
advantages += return_t - values
@@ -244,9 +260,18 @@ def vtrace(
return targets, advantages
+class VtraceSepCarry(NamedTuple):
+v: jnp.ndarray
+next_values: jnp.ndarray
+reward: jnp.ndarray
+xi: jnp.ndarray
+last_return: jnp.ndarray
+next_q: jnp.ndarray
def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
-v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
-last_return1, last_return2, next_q1, next_q2 = carry
+(v1, next_values1, reward1, xi1, last_return1, next_q1), \
+(v2, next_values2, reward2, xi2, last_return2, next_q2) = carry
ratio, cur_values, next_done, r_t, main = inp
v1 = jnp.where(next_done, 0, v1)
@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
xi1 = jnp.where(main, 1, ratio * xi1)
xi2 = jnp.where(main, ratio * xi2, 1)
-carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
-last_return1, last_return2, next_q1, next_q2
-return carry, (v, q_t, return_t)
+carry1 = VtraceSepCarry(v1, next_values1, reward1, xi1, last_return1, next_q1)
+carry2 = VtraceSepCarry(v2, next_values2, reward2, xi2, last_return2, next_q2)
+return (carry1, carry2), (v, q_t, return_t)
def vtrace_sep(
-next_value, ratios, values, rewards, next_dones, mains,
-gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+next_v, ratios, values, rewards, next_dones, mains,
+gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
+upgo=False, return_carry=False
):
-next_value1 = next_value
-next_value2 = -next_value1
-v1 = return1 = next_q1 = next_value1
-v2 = return2 = next_q2 = next_value2
-reward1 = reward2 = jnp.zeros_like(next_value)
-xi1 = xi2 = jnp.ones_like(next_value)
-carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
-return1, return2, next_q1, next_q2
+if isinstance(next_v, (tuple, list)):
+carry = next_v
+else:
+next_value = next_v
+next_value1 = next_value
+carry1 = VtraceSepCarry(
+v=next_value1,
+next_values=next_value1,
+reward=jnp.zeros_like(next_value),
+xi=jnp.ones_like(next_value),
+last_return=next_value1,
+next_q=next_value1,
+)
+next_value2 = -next_value1
+carry2 = VtraceSepCarry(
+v=next_value2,
+next_values=next_value2,
+reward=jnp.zeros_like(next_value),
+xi=jnp.ones_like(next_value),
+last_return=next_value2,
+next_q=next_value2,
+)
+carry = carry1, carry2
-_, (targets, q_estimate, return_t) = jax.lax.scan(
+carry, (targets, q_estimate, return_t) = jax.lax.scan(
partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
carry, (ratios, values, next_dones, rewards, mains), reverse=True
)
+if return_carry:
+return carry
advantages = q_estimate - values
if upgo:
advantages += return_t - values
@@ -325,9 +368,18 @@ def vtrace_sep(
return targets, advantages
+class GAESepCarry(NamedTuple):
+lastgaelam: jnp.ndarray
+next_value: jnp.ndarray
+reward: jnp.ndarray
+done_used: jnp.ndarray
+last_return: jnp.ndarray
+next_q: jnp.ndarray
def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
-lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
+(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1), \
+(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2) = carry
cur_value, next_done, reward, main = inp
main1 = main
main2 = ~main
@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)
-carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
-return carry, (advantages, returns)
+carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1)
+carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2)
+return (carry1, carry2), (advantages, returns)
def truncated_gae_sep(
-next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
-):
-next_value1 = next_value
-next_value2 = -next_value1
-last_return1 = next_q1 = next_value1
-last_return2 = next_q2 = next_value2
-done_used1 = jnp.ones_like(next_dones[-1])
-done_used2 = jnp.ones_like(next_dones[-1])
-reward1 = reward2 = jnp.zeros_like(next_value)
-lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
-carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
-_, (advantages, returns) = jax.lax.scan(
+next_v, values, rewards, next_dones, mains, gamma, gae_lambda, upgo, return_carry=False):
+if isinstance(next_v, (tuple, list)):
+carry = next_v
+else:
+next_value = next_v
+carry1 = GAESepCarry(
+lastgaelam=jnp.zeros_like(next_value),
+next_value=next_value,
+reward=jnp.zeros_like(next_value),
+done_used=jnp.ones_like(next_dones[-1]),
+last_return=next_value,
+next_q=next_value,
+)
+carry2 = GAESepCarry(
+lastgaelam=jnp.zeros_like(next_value),
+next_value=-next_value,
+reward=jnp.zeros_like(next_value),
+done_used=jnp.ones_like(next_dones[-1]),
+last_return=-next_value,
+next_q=-next_value,
+)
+carry = carry1, carry2
+carry, (advantages, returns) = jax.lax.scan(
partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True
)
+if return_carry:
+return carry
targets = values + advantages
if upgo:
advantages += returns - values
@@ -400,6 +463,14 @@ def truncated_gae_sep(
return targets, advantages
+class GAECarry(NamedTuple):
+lastgaelam: jnp.ndarray
+next_value: jnp.ndarray
+last_return: jnp.ndarray
+next_q: jnp.ndarray
+next_main: jnp.ndarray
def truncated_gae_loop(carry, inp, gamma, gae_lambda):
lastgaelam, next_value, last_return, next_q, next_main = carry
cur_value, next_done, reward, main = inp
@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda):
next_q = reward + discount * next_value
-carry = lastgaelam, cur_value, last_return, next_q, main
+carry = GAECarry(lastgaelam, cur_value, last_return, next_q, main)
return carry, (lastgaelam, last_return)
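In the single-player case (all `mains` true) the loop above reduces to the standard GAE recurrence; a plain reference version, with assumed shapes `(T, B)` for the time series and `(B,)` for the bootstrap value:

```python
import jax.numpy as jnp

def gae_reference(next_value, values, rewards, next_dones, gamma, lam):
    # delta_t = r_t + gamma * (1 - done_{t+1}) * V_{t+1} - V_t
    # A_t     = delta_t + gamma * lam * (1 - done_{t+1}) * A_{t+1}
    T = values.shape[0]
    advantages = jnp.zeros_like(values)
    lastgaelam = jnp.zeros_like(next_value)
    for t in reversed(range(T)):
        nv = next_value if t == T - 1 else values[t + 1]
        nonterminal = 1.0 - next_dones[t].astype(values.dtype)
        delta = rewards[t] + gamma * nv * nonterminal - values[t]
        lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
        advantages = advantages.at[t].set(lastgaelam)
    return values + advantages, advantages  # targets, advantages
```

The `last_return`/`next_q` fields extend this with the UPGO return and turn handling, which the reference omits.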
def truncated_gae(
-next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False):
-lastgaelam = jnp.zeros_like(next_value)
-last_return = next_q = next_value
-next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
-carry = lastgaelam, next_value, last_return, next_q, next_main
-_, (advantages, return_t) = jax.lax.scan(
+next_v, values, rewards, next_dones, mains, gamma, gae_lambda,
+upgo=False, return_carry=False):
+if isinstance(next_v, (tuple, list)):
+carry = next_v
+else:
+next_value = next_v
+carry = GAECarry(
+lastgaelam=jnp.zeros_like(next_value),
+next_value=next_value,
+last_return=next_value,
+next_q=next_value,
+next_main=jnp.ones_like(next_value, dtype=jnp.bool_),
+)
+carry, (advantages, return_t) = jax.lax.scan(
partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
carry, (values, next_dones, rewards, mains), reverse=True
)
+if return_carry:
+return carry
targets = values + advantages
if upgo:
advantages += return_t - values
......
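A usage sketch of `return_carry` with `truncated_gae` (toy data, all `mains` true): scanning the tail of a sequence first and feeding the resulting `GAECarry` back in as `next_v` should reproduce the head of a single full-sequence pass, which is what lets the actor precompute the carry from the lookahead steps.

```python
import jax.numpy as jnp

T, B, split = 8, 2, 5
values = jnp.ones((T, B))
rewards = jnp.ones((T, B))
next_dones = jnp.zeros((T, B), dtype=jnp.bool_)
mains = jnp.ones((T, B), dtype=jnp.bool_)
next_value = jnp.zeros(B)

full_targets, full_adv = truncated_gae(
    next_value, values, rewards, next_dones, mains, 0.99, 0.95)
carry = truncated_gae(
    next_value, values[split:], rewards[split:], next_dones[split:],
    mains[split:], 0.99, 0.95, return_carry=True)
head_targets, head_adv = truncated_gae(
    carry, values[:split], rewards[:split], next_dones[:split],
    mains[:split], 0.99, 0.95)
# Expect: head_targets ~= full_targets[:split], head_adv ~= full_adv[:split]
```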