Commit 7c8b11c3 authored by sbl1996@126.com

Add collect_steps

parent e1ff8f92
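Note: the change below makes the actor gather `collect_steps` environment steps per iteration while the learner still trains on `num_steps`-long windows; the extra tail is used to bootstrap the advantage targets and is kept as the head of the next iteration's buffer. A minimal, self-contained sketch of that bookkeeping (plain Python, numbers made up for illustration):

```python
# Hypothetical illustration of the rollout-buffer bookkeeping added in this commit.
num_steps, collect_steps = 128, 160

storage = []          # list of per-step transitions, persists across updates
start_step = 0        # the first iteration collects the full collect_steps
for update in range(3):
    for k in range(start_step, collect_steps):
        storage.append(f"step@{update}:{k}")   # stand-in for a Transition
    start_step = collect_steps - num_steps     # tail of the buffer is kept ...
    storage_t, storage = storage[:num_steps], storage[num_steps:]
    # storage_t (num_steps items) goes to the learner; the remaining
    # len(storage) == collect_steps - num_steps items bootstrap the advantages
    # and become the head of the next iteration's buffer.
    assert len(storage_t) == num_steps
```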
@@ -97,6 +97,8 @@ class Args:
     """the number of actor threads to use"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
+    collect_steps: Optional[int] = None
+    """the number of environment steps to collect per rollout before computing the advantages (defaults to num_steps; must be >= num_steps)"""
     segment_length: Optional[int] = None
     """the length of the segment for training"""
     anneal_lr: bool = False
@@ -226,6 +228,7 @@ class Transition(NamedTuple):
     dones: list
     actions: list
     logits: list
+    values: list
     rewards: list
     mains: list
     next_dones: list
@@ -304,6 +307,31 @@ def reshape_minibatch(
     return x


+def advantage_fn(
+    args, next_v, values, rewards, next_dones, switch_or_mains, ratios=None, return_carry=False):
+    if args.switch:
+        if args.value == "vtrace" or args.sep_value or return_carry:
+            raise NotImplementedError
+        return gae_sep_switch(
+            next_v, values, rewards, next_dones, switch_or_mains,
+            args.gamma, args.gae_lambda, args.upgo)
+    else:
+        # TODO: TD(lambda) for multi-step
+        if args.value == "gae":
+            adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
+            return adv_fn(
+                next_v, values, rewards, next_dones, switch_or_mains,
+                args.gamma, args.gae_lambda, args.upgo, return_carry=return_carry)
+        else:
+            adv_fn = vtrace_sep if args.sep_value else vtrace
+            if ratios is None:
+                ratios = jnp.ones_like(values)
+            return adv_fn(
+                next_v, ratios, values, rewards, next_dones, switch_or_mains, args.gamma,
+                args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max,
+                return_carry=return_carry)
+
+
 def rollout(
     key: jax.random.PRNGKey,
     args: Args,
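A rough usage sketch of the new `advantage_fn` dispatcher, assuming it and the `truncated_gae`/`vtrace` helpers from this repository are in scope; the `Args` stand-in and the shapes below are made up and only cover the default non-switch GAE path:

```python
from types import SimpleNamespace
import jax.numpy as jnp

# Hypothetical stand-in for Args; only the fields advantage_fn reads are set.
args = SimpleNamespace(
    switch=False, sep_value=False, value="gae",
    gamma=0.997, gae_lambda=0.95, upgo=False)

T, B = 128, 16                                   # time steps, parallel envs
values = jnp.zeros((T, B))
rewards = jnp.zeros((T, B))
next_dones = jnp.zeros((T, B), dtype=jnp.bool_)
mains = jnp.ones((T, B), dtype=jnp.bool_)
next_v = jnp.zeros((B,))                         # bootstrap value per env

# Normal call: dispatches to truncated_gae and returns targets and advantages.
targets, advantages = advantage_fn(args, next_v, values, rewards, next_dones, mains)

# return_carry=True instead returns the scan carry after sweeping the window,
# which can later be fed back in as next_v for an earlier chunk.
carry = advantage_fn(args, next_v, values, rewards, next_dones, mains, return_carry=True)
```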
@@ -370,10 +398,17 @@ def rollout(
     @jax.jit
     def sample_action(
             params, next_obs, rstate1, rstate2, main, done, key):
-        (rstate1, rstate2), logits = agent.apply(
-            params, next_obs, (rstate1, rstate2), done, main)[:2]
+        (rstate1, rstate2), logits, value = agent.apply(
+            params, next_obs, (rstate1, rstate2), done, main)[:3]
+        value = jnp.squeeze(value, axis=-1)
         action, key = categorical_sample(logits, key)
-        return next_obs, done, main, rstate1, rstate2, action, logits, key
+        return next_obs, done, main, rstate1, rstate2, action, logits, value, key
+
+    @jax.jit
+    def compute_advantage_carry(
+            next_value, values, rewards, next_dones, mains):
+        return advantage_fn(
+            args, next_value, values, rewards, next_dones, mains, return_carry=True)

     deck_names = args.deck_names
     deck_avg_times = {name: 0 for name in deck_names}
@@ -400,11 +435,17 @@ def rollout(
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(main_player)
+    start_step = 0
     storage = []
+    init_rstates = []
+
+    # @jax.jit
+    # def prepare_data(storage: List[Transition]) -> Transition:
+    #     return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)

     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:
-        return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
+        return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)

     for update in range(1, args.num_updates + 2):
         if update == 10:
@@ -426,16 +467,18 @@ def rollout(
         params_queue_get_time.append(time.time() - params_queue_get_time_start)
         rollout_time_start = time.time()

-        init_rstate1, init_rstate2 = jax.tree.map(
-            lambda x: x.copy(), (next_rstate1, next_rstate2))
-        for k in range(args.num_steps):
+        for k in range(start_step, args.collect_steps):
+            if k % args.num_steps == 0:
+                init_rstate1, init_rstate2 = jax.tree.map(
+                    lambda x: x.copy(), (next_rstate1, next_rstate2))
+                init_rstates.append((init_rstate1, init_rstate2))
             global_step += args.local_num_envs * n_actors * args.world_size

             main = next_to_play == main_player

             inference_time_start = time.time()
             cached_next_obs, cached_next_done, cached_main, \
-                next_rstate1, next_rstate2, action, logits, key = sample_action(
+                next_rstate1, next_rstate2, action, logits, value, key = sample_action(
                     params, next_obs, next_rstate1, next_rstate2, main, next_done, key)
             cpu_action = np.array(action)
@@ -453,6 +496,7 @@ def rollout(
                 mains=cached_main,
                 actions=action,
                 logits=logits,
+                values=value,
                 rewards=next_reward,
                 next_dones=next_done,
             )
@@ -495,8 +539,29 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)

-        partitioned_storage = prepare_data(storage)
-        storage = []
+        start_step = args.collect_steps - args.num_steps
+        next_main = main_player == next_to_play
+        if args.collect_steps == args.num_steps:
+            storage_t = storage
+            storage = []
+            next_data = (next_obs, next_main)
+        else:
+            storage_t = storage[:args.num_steps]
+            storage = storage[args.num_steps:]
+            values, rewards, next_dones, mains = prepare_data([
+                (t.values, t.rewards, t.next_dones, t.mains) for t in storage])
+            next_value = sample_action(
+                params, next_obs, next_rstate1, next_rstate2, next_main, next_done, key)[-2]
+            next_value = jnp.where(next_main, next_value, -next_value)
+            adv_carry = compute_advantage_carry(
+                next_value, values, rewards, next_dones, mains)
+            next_data = adv_carry
+
+        partitioned_storage = jax.tree.map(
+            lambda x: jnp.split(x, len(learner_devices), axis=1), prepare_data(storage_t))
         sharded_storage = []
         for x in partitioned_storage:
             if isinstance(x, dict):
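When `collect_steps > num_steps`, the tail transitions are scanned once on the actor (`compute_advantage_carry`) and only the resulting carry is shipped to the learner as `next_data`, so the learner never needs the tail observations. The bootstrap value feeding that scan is first flipped into the main player's frame; under the zero-sum two-player convention used throughout this script, that is a single `jnp.where`, as in this standalone illustration (numbers made up):

```python
import jax.numpy as jnp

# Per-env value predicted for whoever acts next, and a mask saying whether
# that player is the "main" player of the env.
value_for_next_player = jnp.array([0.3, -0.7, 0.1, 0.9])
next_main = jnp.array([True, False, False, True])

# Zero-sum convention: the main player's value is minus the opponent's.
value_for_main_player = jnp.where(
    next_main, value_for_next_player, -value_for_next_player)
print(value_for_main_player)   # [ 0.3  0.7 -0.1  0.9]
```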
@@ -508,10 +573,11 @@ def rollout(
                 x = jax.device_put_sharded(x, devices=learner_devices)
             sharded_storage.append(x)
         sharded_storage = Transition(*sharded_storage)
-        next_main = main_player == next_to_play
+        init_rstate = init_rstates.pop(0)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (init_rstate1, init_rstate2, next_obs, next_main))
+            (init_rstate, next_data))

         if args.eval_interval and update % args.eval_interval == 0:
             _start = time.time()
@@ -594,6 +660,8 @@ def main():
     args.local_env_threads = args.local_env_threads or args.local_num_envs
     if args.segment_length is not None:
         assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"
+    args.collect_steps = args.collect_steps or args.num_steps
+    assert args.collect_steps >= args.num_steps, "collect_steps must be greater than or equal to num_steps"

     if args.embedding_file:
         embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -734,10 +802,10 @@ def main():
     else:
         eval_variables = None

-    def advantage_fn(
+    def compute_advantage(
         new_logits, new_values, next_dones, switch_or_mains,
-        actions, logits, rewards, next_value):
-        num_envs = jax.tree.leaves(next_value)[0].shape[0]
+        actions, logits, rewards, next_v):
+        num_envs = jax.tree.leaves(next_v)[0].shape[0]
         num_steps = next_dones.shape[0] // num_envs

         def reshape_time_series(x):
@@ -745,37 +813,20 @@ def main():
         ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
+        ratios = reshape_time_series(ratios)

         new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
             reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
         )

-        # Advantages and target values
-        if args.switch:
-            if args.value == "vtrace" or args.sep_value:
-                raise NotImplementedError
-            target_values, advantages = gae_sep_switch(
-                next_value, new_values_, rewards, next_dones, switch_or_mains,
-                args.gamma, args.gae_lambda, args.upgo)
-        else:
-            # TODO: TD(lambda) for multi-step
-            ratios_ = reshape_time_series(ratios)
-            if args.value == "gae":
-                adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
-                target_values, advantages = adv_fn(
-                    next_value, new_values_, rewards, next_dones, switch_or_mains,
-                    args.gamma, args.gae_lambda, args.upgo)
-            else:
-                adv_fn = vtrace_sep if args.sep_value else vtrace
-                target_values, advantages = adv_fn(
-                    next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains, args.gamma,
-                    args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
+        target_values, advantages = advantage_fn(
+            args, next_v, new_values_, rewards, next_dones, switch_or_mains, ratios)

         target_values, advantages = jax.tree.map(
             lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
         return target_values, advantages

-    def loss_fn(
+    def compute_loss(
         new_logits, new_values, actions, logits, target_values, advantages,
         mask, num_steps=None):
         ratios = distrax.importance_sampling_ratios(distrax.Categorical(
@@ -820,17 +871,24 @@ def main():
         loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
         return loss, pg_loss, v_loss, ent_loss, approx_kl

-    def apply_fn(variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
+    def apply_fn(
+        variables, obs, init_rstate, dones, next_dones, switch_or_mains, train=True):
         if args.switch:
             dones = dones | next_dones
-        ((rstate1, rstate2), new_logits, new_values, _), state_updates = agent.apply(
-            variables, obs, (rstate1, rstate2), dones, switch_or_mains,
-            train=True, mutable=["batch_stats"])
+        mutable = ["batch_stats"] if train else False
+        rets = agent.apply(
+            variables, obs, init_rstate, dones, switch_or_mains,
+            train=train, mutable=mutable)
+        if train:
+            ((rstate1, rstate2), new_logits, new_values, _), state_updates = rets
+        else:
+            (rstate1, rstate2), new_logits, new_values, _ = rets
+            state_updates = {}
         new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
         return ((rstate1, rstate2), new_logits, new_values), state_updates

-    def compute_next_value(
-        variables, rstate1, rstate2, next_obs, next_main):
+    def compute_next_value(variables, next_rstate, next_obs, next_main):
+        rstate1, rstate2 = next_rstate
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), rstate1, rstate2)
         next_value = agent.apply(variables, next_obs, rstate)[2]
@@ -840,39 +898,39 @@ def main():
         next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
         return next_value

-    def compute_advantage(
-        variables, rstate1, rstate2, obs, dones, next_dones,
+    def get_advantage(
+        variables, init_rstate, obs, dones, next_dones,
         switch_or_mains, actions, logits, rewards, next_obs, next_main):
-        segment_length = dones.shape[0]
+        num_steps = dones.shape[0]
         obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
             jax.tree.map(
                 lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
                 (obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
-        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+        (next_rstate, new_logits, new_values), state_updates = apply_fn(
+            variables, obs, init_rstate, dones, next_dones, switch_or_mains, train=False)
         next_value = compute_next_value(
-            variables, rstate1, rstate2, next_obs, next_main)
+            variables, next_rstate, next_obs, next_main)
-        target_values, advantages = advantage_fn(
+        target_values, advantages = compute_advantage(
             new_logits, new_values, next_dones, switch_or_mains,
             actions, logits, rewards, next_value)
         target_values, advantages = jax.tree.map(
-            lambda x: jnp.reshape(x, (segment_length, -1) + x.shape[2:]),
+            lambda x: jnp.reshape(x, (num_steps, -1) + x.shape[2:]),
             (target_values, advantages))
         return target_values, advantages

-    def compute_loss(
-        params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
+    def get_loss(
+        params, batch_stats, init_rstate, obs, dones, next_dones,
         switch_or_mains, actions, logits, target_values, advantages, mask):
         variables = {'params': params, 'batch_stats': batch_stats}
         ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+            variables, obs, init_rstate, dones, next_dones, switch_or_mains)
-        loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+        loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
             new_logits, new_values, actions, logits, target_values, advantages,
             mask, num_steps=None)
@@ -881,23 +939,27 @@ def main():
             jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
         return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)

-    def compute_advantage_loss(
-        params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
-        switch_or_mains, actions, logits, rewards, mask, next_obs, next_main):
-        num_envs = jax.tree.leaves(next_main)[0].shape[0]
+    def get_advantage_loss(
+        params, batch_stats, init_rstate, obs, dones, next_dones,
+        switch_or_mains, actions, logits, rewards, mask, next_data):
+        num_envs = jax.tree.leaves(next_data)[0].shape[0]
         variables = {'params': params, 'batch_stats': batch_stats}
-        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
-        variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
-        next_value = compute_next_value(
-            variables, rstate1, rstate2, next_obs, next_main)
+        (next_rstate, new_logits, new_values), state_updates = apply_fn(
+            variables, obs, init_rstate, dones, next_dones, switch_or_mains)
+        if args.collect_steps == args.num_steps:
+            next_obs, next_main = next_data
+            variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
+            next_v = compute_next_value(
+                variables, next_rstate, next_obs, next_main)
+        else:
+            next_v = next_data

-        target_values, advantages = advantage_fn(
+        target_values, advantages = compute_advantage(
             new_logits, new_values, next_dones, switch_or_mains,
-            actions, logits, rewards, next_value)
+            actions, logits, rewards, next_v)
-        loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+        loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
             new_logits, new_values, actions, logits, target_values, advantages,
             mask, num_steps=dones.shape[0] // num_envs)
@@ -908,18 +970,15 @@ def main():
     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
-        sharded_init_rstate1: List,
-        sharded_init_rstate2: List,
-        sharded_next_obs: List,
-        sharded_next_main: List,
+        sharded_init_rstate: List,
+        sharded_next_data: List,
         key: jax.random.PRNGKey,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-        next_obs, init_rstate1, init_rstate2 = [
+        next_data, init_rstate = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-            for x in [sharded_next_obs, sharded_init_rstate1, sharded_init_rstate2]
+            for x in [sharded_next_data, sharded_init_rstate]
         ]
-        next_main = jnp.concatenate(sharded_next_main)

         # reorder storage of individual players
         # main first, opponent second
@@ -934,9 +993,10 @@ def main():
         storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

         if args.segment_length is None:
-            loss_grad_fn = jax.value_and_grad(compute_advantage_loss, has_aux=True)
+            loss_grad_fn = jax.value_and_grad(get_advantage_loss, has_aux=True)
         else:
-            loss_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
+            # TODO: fix it
+            loss_grad_fn = jax.value_and_grad(get_loss, has_aux=True)

         def update_epoch(carry, _):
             agent_state, key = carry
@@ -947,9 +1007,9 @@ def main():
                 return reshape_minibatch(
                     x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)

-            b_init_rstate1, b_init_rstate2, b_next_obs, b_next_main = \
+            b_init_rstate, b_next_data = \
                 jax.tree.map(partial(convert_data, multi_step=False),
-                             (init_rstate1, init_rstate2, next_obs, next_main))
+                             (init_rstate, next_data))
             b_storage = jax.tree.map(convert_data, storage)
             if args.switch:
                 switch_or_mains = convert_data(switch)
@@ -969,31 +1029,30 @@ def main():
             else:
                 def update_minibatch(carry, minibatch):
                     def update_minibatch_t(carry, minibatch_t):
-                        agent_state, rstate1, rstate2 = carry
-                        minibatch_t = rstate1, rstate2, *minibatch_t
-                        (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
+                        agent_state, init_rstate = carry
+                        minibatch_t = init_rstate, *minibatch_t
+                        (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, next_rstate)), \
                             grads = loss_grad_fn(agent_state.params, agent_state.batch_stats, *minibatch_t)
                         grads = jax.lax.pmean(grads, axis_name="local_devices")
                         agent_state = agent_state.apply_gradients(grads=grads)
                         agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
-                        return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
+                        return (agent_state, next_rstate), (loss, pg_loss, v_loss, ent_loss, approx_kl)

-                    rstate1, rstate2, *minibatch_t, mask = minibatch
-                    target_values, advantages = compute_advantage(
-                        get_variables(carry), rstate1, rstate2, *minibatch_t)
+                    init_rstate, *minibatch_t, mask = minibatch
+                    target_values, advantages = get_advantage(
+                        get_variables(carry), init_rstate, *minibatch_t)
                     minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
-                    (carry, _rstate1, _rstate2), \
+                    (carry, _next_rstate), \
                         (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
-                            update_minibatch_t, (carry, rstate1, rstate2), minibatch_t)
+                            update_minibatch_t, (carry, init_rstate), minibatch_t)
                     return carry, (loss, pg_loss, v_loss, ent_loss, approx_kl)

                 agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
                     update_minibatch,
                     agent_state,
                     (
-                        b_init_rstate1,
-                        b_init_rstate2,
+                        b_init_rstate,
                         b_storage.obs,
                         b_storage.dones,
                         b_storage.next_dones,
@@ -1002,8 +1061,7 @@ def main():
                         b_storage.logits,
                         b_rewards,
                         b_mask,
-                        b_next_obs,
-                        b_next_main,
+                        b_next_data,
                     ),
                 )
                 return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
@@ -1023,7 +1081,7 @@ def main():
         axis_name="main_devices",
         devices=global_main_devices,
     )
     multi_device_update = jax.pmap(
         single_device_update,
         axis_name="local_devices",
...
 from functools import partial
+from typing import NamedTuple

 import jax
 import jax.numpy as jnp
@@ -193,6 +194,14 @@ def vtrace_rnad(
     return targets, q_estimate


+class VtraceCarry(NamedTuple):
+    v: jnp.ndarray
+    next_value: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+    next_main: jnp.ndarray
+
+
 def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     v, next_value, last_return, next_q, next_main = carry
     ratio, cur_value, next_done, reward, main = inp
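These carries are plain `NamedTuple`s of arrays, so they are pytrees: `jax.lax.scan` threads them as-is, and the actor/learner plumbing can split, shard, and concatenate them exactly like the `(next_obs, next_main)` payload they replace. A standalone check of that property (the class is re-declared here only so the snippet runs on its own):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Same field layout as the VtraceCarry added above, re-declared for a runnable demo.
class VtraceCarry(NamedTuple):
    v: jnp.ndarray
    next_value: jnp.ndarray
    last_return: jnp.ndarray
    next_q: jnp.ndarray
    next_main: jnp.ndarray

B = 8
carry = VtraceCarry(
    v=jnp.zeros((B,)), next_value=jnp.zeros((B,)), last_return=jnp.zeros((B,)),
    next_q=jnp.zeros((B,)), next_main=jnp.ones((B,), dtype=jnp.bool_))

# Pytree behaviour: the structure survives tree-wide split/concatenate.
halves = jax.tree.map(lambda x: jnp.split(x, 2), carry)          # still a VtraceCarry
merged = jax.tree.map(lambda *x: jnp.concatenate(x), carry, carry)
print(type(halves).__name__, merged.v.shape)                      # VtraceCarry (16,)

# isinstance(carry, tuple) is True, which is how vtrace() below distinguishes
# a resumed carry from a fresh bootstrap value.
print(isinstance(carry, (tuple, list)))                           # True
```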
@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
         next_q = reward + discount * next_value

-    carry = v, cur_value, last_return, next_q, main
+    carry = VtraceCarry(v, next_value, last_return, next_q, main)
     return carry, (v, q_t, last_return)


 def vtrace(
-    next_value, ratios, values, rewards, next_dones, mains,
-    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+    next_v, ratios, values, rewards, next_dones, mains,
+    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
+    upgo=False, return_carry=False
 ):
-    v = last_return = next_q = next_value
-    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
-    carry = v, next_value, last_return, next_q, next_main
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        v = last_return = next_q = next_value
+        next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+        carry = VtraceCarry(v, next_value, last_return, next_q, next_main)

-    _, (targets, q_estimate, return_t) = jax.lax.scan(
+    carry, (targets, q_estimate, return_t) = jax.lax.scan(
         partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
         carry, (ratios, values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     advantages = q_estimate - values
     if upgo:
         advantages += return_t - values
@@ -244,9 +260,18 @@ def vtrace(
     return targets, advantages


+class VtraceSepCarry(NamedTuple):
+    v: jnp.ndarray
+    next_values: jnp.ndarray
+    reward: jnp.ndarray
+    xi: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+
+
 def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
-    v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
-        last_return1, last_return2, next_q1, next_q2 = carry
+    (v1, next_values1, reward1, xi1, last_return1, next_q1), \
+        (v2, next_values2, reward2, xi2, last_return2, next_q2) = carry
     ratio, cur_values, next_done, r_t, main = inp

     v1 = jnp.where(next_done, 0, v1)
@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     xi1 = jnp.where(main, 1, ratio * xi1)
     xi2 = jnp.where(main, ratio * xi2, 1)

-    carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
-        last_return1, last_return2, next_q1, next_q2
-    return carry, (v, q_t, return_t)
+    carry1 = VtraceSepCarry(v1, next_values1, reward1, xi1, last_return1, next_q1)
+    carry2 = VtraceSepCarry(v2, next_values2, reward2, xi2, last_return2, next_q2)
+    return (carry1, carry2), (v, q_t, return_t)


 def vtrace_sep(
-    next_value, ratios, values, rewards, next_dones, mains,
-    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+    next_v, ratios, values, rewards, next_dones, mains,
+    gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0,
+    upgo=False, return_carry=False
 ):
-    next_value1 = next_value
-    next_value2 = -next_value1
-    v1 = return1 = next_q1 = next_value1
-    v2 = return2 = next_q2 = next_value2
-    reward1 = reward2 = jnp.zeros_like(next_value)
-    xi1 = xi2 = jnp.ones_like(next_value)
-    carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
-        return1, return2, next_q1, next_q2
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        next_value1 = next_value
+        carry1 = VtraceSepCarry(
+            v=next_value1,
+            next_values=next_value1,
+            reward=jnp.zeros_like(next_value),
+            xi=jnp.ones_like(next_value),
+            last_return=next_value1,
+            next_q=next_value1,
+        )
+        next_value2 = -next_value1
+        carry2 = VtraceSepCarry(
+            v=next_value2,
+            next_values=next_value2,
+            reward=jnp.zeros_like(next_value),
+            xi=jnp.ones_like(next_value),
+            last_return=next_value2,
+            next_q=next_value2,
+        )
+        carry = carry1, carry2

-    _, (targets, q_estimate, return_t) = jax.lax.scan(
+    carry, (targets, q_estimate, return_t) = jax.lax.scan(
         partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
         carry, (ratios, values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     advantages = q_estimate - values
     if upgo:
         advantages += return_t - values
@@ -325,9 +368,18 @@ def vtrace_sep(
     return targets, advantages


+class GAESepCarry(NamedTuple):
+    lastgaelam: jnp.ndarray
+    next_value: jnp.ndarray
+    reward: jnp.ndarray
+    done_used: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+
+
 def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
-    lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
+    (lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1), \
+        (lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2) = carry
     cur_value, next_done, reward, main = inp
     main1 = main
     main2 = ~main
@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
     lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
     lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)

-    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
-    return carry, (advantages, returns)
+    carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1)
+    carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2)
+    return (carry1, carry2), (advantages, returns)


 def truncated_gae_sep(
-    next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
-):
-    next_value1 = next_value
-    next_value2 = -next_value1
-    last_return1 = next_q1 = next_value1
-    last_return2 = next_q2 = next_value2
-    done_used1 = jnp.ones_like(next_dones[-1])
-    done_used2 = jnp.ones_like(next_dones[-1])
-    reward1 = reward2 = jnp.zeros_like(next_value)
-    lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
-    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
+    next_v, values, rewards, next_dones, mains, gamma, gae_lambda, upgo, return_carry=False):
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        carry1 = GAESepCarry(
+            lastgaelam=jnp.zeros_like(next_value),
+            next_value=next_value,
+            reward=jnp.zeros_like(next_value),
+            done_used=jnp.ones_like(next_dones[-1]),
+            last_return=next_value,
+            next_q=next_value,
+        )
+        carry2 = GAESepCarry(
+            lastgaelam=jnp.zeros_like(next_value),
+            next_value=-next_value,
+            reward=jnp.zeros_like(next_value),
+            done_used=jnp.ones_like(next_dones[-1]),
+            last_return=-next_value,
+            next_q=-next_value,
+        )
+        carry = carry1, carry2

-    _, (advantages, returns) = jax.lax.scan(
+    carry, (advantages, returns) = jax.lax.scan(
         partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
         carry, (values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     targets = values + advantages
     if upgo:
         advantages += returns - values
@@ -400,6 +463,14 @@ def truncated_gae_sep(
     return targets, advantages


+class GAECarry(NamedTuple):
+    lastgaelam: jnp.ndarray
+    next_value: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+    next_main: jnp.ndarray
+
+
 def truncated_gae_loop(carry, inp, gamma, gae_lambda):
     lastgaelam, next_value, last_return, next_q, next_main = carry
     cur_value, next_done, reward, main = inp
@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda):
         next_q = reward + discount * next_value

-    carry = lastgaelam, cur_value, last_return, next_q, main
+    carry = GAECarry(lastgaelam, cur_value, last_return, next_q, main)
     return carry, (lastgaelam, last_return)


 def truncated_gae(
-    next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False):
-    lastgaelam = jnp.zeros_like(next_value)
-    last_return = next_q = next_value
-    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
-    carry = lastgaelam, next_value, last_return, next_q, next_main
+    next_v, values, rewards, next_dones, mains, gamma, gae_lambda,
+    upgo=False, return_carry=False):
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        carry = GAECarry(
+            lastgaelam=jnp.zeros_like(next_value),
+            next_value=next_value,
+            last_return=next_value,
+            next_q=next_value,
+            next_main=jnp.ones_like(next_value, dtype=jnp.bool_),
+        )

-    _, (advantages, return_t) = jax.lax.scan(
+    carry, (advantages, return_t) = jax.lax.scan(
         partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
         carry, (values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     targets = values + advantages
     if upgo:
         advantages += return_t - values
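The `return_carry` plumbing added to these functions lets the advantage recursion be computed in chunks: run the reverse scan over the tail first, then resume the head from the returned carry. A deliberately simplified, single-player stand-in (no dones, mains, or upgo handling) that demonstrates the pattern and its equivalence with a single full-sequence pass; the names `gae`/`gae_loop` are illustrative, not part of the repository:

```python
from functools import partial
import jax
import jax.numpy as jnp

# Simplified stand-in for truncated_gae, only to show the return_carry pattern.
def gae_loop(carry, inp, gamma, lam):
    lastgaelam, next_value = carry
    cur_value, reward = inp
    delta = reward + gamma * next_value - cur_value
    lastgaelam = delta + gamma * lam * lastgaelam
    return (lastgaelam, cur_value), lastgaelam

def gae(next_v, values, rewards, gamma=0.99, lam=0.95, return_carry=False):
    # next_v is either a bootstrap value array or a carry from a later chunk.
    carry = next_v if isinstance(next_v, tuple) else (jnp.zeros_like(next_v), next_v)
    carry, advantages = jax.lax.scan(
        partial(gae_loop, gamma=gamma, lam=lam),
        carry, (values, rewards), reverse=True)
    return carry if return_carry else advantages

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
values = jax.random.normal(k1, (12, 4))      # [T, B]
rewards = jax.random.normal(k2, (12, 4))
next_v = jnp.zeros((4,))

full = gae(next_v, values, rewards)                               # one pass over all 12 steps
carry = gae(next_v, values[8:], rewards[8:], return_carry=True)   # tail chunk first
head = gae(carry, values[:8], rewards[:8])                        # resume the head from the carry
print(jnp.allclose(full[:8], head))                               # True: chunked == full-sequence
```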
...