Commit 8cebfebf authored by sbl1996@126.com's avatar sbl1996@126.com

Recompute advantages every minibatch

parent dd06205b
......@@ -259,7 +259,7 @@ def reshape_minibatch(
# (n_mb, num_envs // n_mb, ...)
# else,
# n_mb_t = num_steps // segment_length
# n_mb_e = num_minibatches // num_minibatches1
# n_mb_e = num_minibatches // n_mb_t
# if multi_step, from (num_steps, num_envs, ...)) to
# (n_mb_e, n_mb_t, segment_length * (num_envs // n_mb_e), ...)
# else, from (num_envs, ...) to
......@@ -727,8 +727,8 @@ def main():
eval_params = None
def advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value):
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value):
num_envs = next_value.shape[0]
num_steps = next_dones.shape[0] // num_envs
......@@ -815,12 +815,23 @@ def main():
def compute_advantage(
params, rstate1, rstate2, obs, dones, next_dones,
switch_or_mains, actions, logits, rewards, next_value):
segment_length = dones.shape[0]
obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
jax.tree.map(
lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
(obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
new_logits, new_values = apply_fn(
params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
target_values, advantages = advantage_fn(
new_logits, new_values, next_dones, switch_or_mains,
actions, logits, rewards, next_value)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (segment_length, -1) + x.shape[2:]),
(target_values, advantages))
return target_values, advantages
def compute_loss(
......@@ -888,40 +899,6 @@ def main():
else:
loss_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
def compute_advantage_t(next_value):
N = args.num_minibatches // 4
def convert_data1(x: jnp.ndarray, multi_step=True):
return reshape_minibatch(x, multi_step, N, num_steps)
b_init_rstate1, b_init_rstate2, b_next_value = jax.tree.map(
partial(convert_data1, multi_step=False), (init_rstate1, init_rstate2, next_value))
b_storage = jax.tree.map(convert_data1, storage)
if args.switch:
b_switch_or_mains = convert_data1(switch)
else:
b_switch_or_mains = b_storage.mains
target_values, advantages = jax.lax.scan(
lambda x, y: (x, compute_advantage(x, *y)),
agent_state.params,
(
b_init_rstate1,
b_init_rstate2,
b_storage.obs,
b_storage.dones,
b_storage.next_dones,
b_switch_or_mains,
b_storage.actions,
b_storage.logits,
b_storage.rewards,
b_next_value,
))[1]
target_values, advantages = jax.tree.map(
partial(reshape_batch, num_minibatches=N, num_steps=num_steps),
(target_values, advantages))
return target_values, advantages
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
......@@ -938,7 +915,6 @@ def main():
return reshape_minibatch(
x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)
shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
shuffled_storage = jax.tree.map(convert_data, storage)
......@@ -947,10 +923,9 @@ def main():
else:
switch_or_mains = shuffled_storage.mains
shuffled_mask = ~shuffled_storage.dones
shuffled_next_value = convert_data(next_value, multi_step=False)
if args.segment_length is None:
shuffled_next_value = convert_data(next_value, multi_step=False)
others = shuffled_storage.rewards, shuffled_next_value, shuffled_mask
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
agent_state.params, *minibatch)
......@@ -958,10 +933,6 @@ def main():
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
else:
target_values, advantages = compute_advantage_t(next_value)
shuffled_target_values, shuffled_advantages = jax.tree.map(
convert_data, (target_values, advantages))
others = shuffled_target_values, shuffled_advantages, shuffled_mask
def update_minibatch(agent_state, minibatch):
def update_minibatch_t(carry, minibatch_t):
agent_state, rstate1, rstate2 = carry
......@@ -972,7 +943,11 @@ def main():
agent_state = agent_state.apply_gradients(grads=grads)
return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
rstate1, rstate2, *minibatch_t = minibatch
rstate1, rstate2, *minibatch_t, mask = minibatch
target_values, advantages = compute_advantage(
agent_state.params, rstate1, rstate2, *minibatch_t)
minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
(agent_state, _rstate1, _rstate2), \
(loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
update_minibatch_t, (agent_state, rstate1, rstate2), minibatch_t)
......@@ -990,7 +965,9 @@ def main():
switch_or_mains,
shuffled_storage.actions,
shuffled_storage.logits,
*others,
shuffled_storage.rewards,
shuffled_next_value,
shuffled_mask
),
)
return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment