Commit 0e9969c5 authored by sbl1996@126.com

Unify PPO and Impala to Cleanba

parent 907b51bc
This diff is collapsed.
@@ -73,12 +73,12 @@ class Args:
 """the maximum number of options"""
 n_history_actions: int = 32
 """the number of history actions to use"""
-greedy_reward: bool = True
+greedy_reward: bool = False
 """whether to use greedy reward (faster kill higher reward)"""
-total_timesteps: int = 5000000000
+total_timesteps: int = 50000000000
 """total timesteps of the experiments"""
-learning_rate: float = 1e-3
+learning_rate: float = 3e-4
 """the learning rate of the optimizer"""
 local_num_envs: int = 128
 """the number of parallel game environments"""
@@ -92,12 +92,12 @@ class Args:
 """Toggle learning rate annealing for policy and value networks"""
 gamma: float = 1.0
 """the discount factor gamma"""
-upgo: bool = False
-"""Toggle the use of UPGO for advantages"""
-num_minibatches: int = 8
+num_minibatches: int = 64
 """the number of mini-batches"""
 update_epochs: int = 2
 """the K epochs to update the policy"""
+upgo: bool = True
+"""Toggle the use of UPGO for advantages"""
 c_clip_min: float = 0.001
 """the minimum value of the importance sampling clipping"""
 c_clip_max: float = 1.007
@@ -141,9 +141,9 @@ class Args:
 eval_checkpoint: Optional[str] = None
 """the path to the model checkpoint to evaluate"""
-local_eval_episodes: int = 32
+local_eval_episodes: int = 128
 """the number of episodes to evaluate the model"""
-eval_interval: int = 50
+eval_interval: int = 100
 """the number of iterations to evaluate the model"""
 # runtime arguments to be filled in
@@ -193,6 +193,7 @@ class Transition(NamedTuple):
 logits: list
 rewards: list
 mains: list
+next_dones: list
 def create_agent(args, multi_step=False):
@@ -203,6 +204,7 @@ def create_agent(args, multi_step=False):
 dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
 param_dtype=jnp.float32,
 lstm_channels=args.rnn_channels,
+switch=False,
 multi_step=multi_step,
 freeze_id=args.freeze_id,
 )
@@ -373,6 +375,7 @@ def rollout(
 actions=action,
 logits=logits,
 rewards=next_reward,
+next_dones=next_done,
 )
 )
@@ -405,7 +408,7 @@ def rollout(
 lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
 sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
 np.split(x, len(learner_devices)), devices=learner_devices),
-(init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
+(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
 if args.eval_interval and update % args.eval_interval == 0:
 _start = time.time()
@@ -616,33 +619,36 @@ if __name__ == "__main__":
 return logits, value.squeeze(-1)
 def loss_fn(
-params, rstate1, rstate2, obs, dones, mains,
-actions, logits, rewards, mask, next_value, next_done):
+params, rstate1, rstate2, obs, dones, next_dones,
+mains, actions, logits, rewards, mask, next_value):
 # (num_steps * local_num_envs // n_mb))
 num_envs = next_value.shape[0]
 num_steps = dones.shape[0] // num_envs
+def reshape_time_series(x):
+return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
 mask = mask * (1.0 - dones)
 n_valids = jnp.sum(mask)
 inputs = (rstate1, rstate2, obs, dones, mains)
 new_logits, new_values = get_logits_and_value(params, inputs)
-new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
-lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
-(new_logits, new_values, logits, actions, rewards, dones, mains, mask),
-)
-next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)
 ratios = distrax.importance_sampling_ratios(distrax.Categorical(
 new_logits), distrax.Categorical(logits), actions)
+logratio = jnp.log(ratios)
+approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
+ratios_, new_values_, rewards, next_dones, mains = jax.tree.map(
+reshape_time_series, (ratios, new_values, rewards, next_dones, mains),
+)
 # TODO: TD(lambda) for multi-step
 target_values, advantages = vtrace_2p0s(
-next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
+next_value, ratios_, new_values_, rewards, next_dones, mains, args.gamma,
 args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
-logratio = jnp.log(ratios)
-approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
+target_values, advantages = jax.tree.map(
+lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
 if args.ppo_clip:
 pg_loss = clipped_surrogate_pg_loss(
@@ -671,7 +677,6 @@ if __name__ == "__main__":
 sharded_init_rstate1: List,
 sharded_init_rstate2: List,
 sharded_next_inputs: List,
-sharded_next_done: List,
 sharded_next_main: List,
 key: jax.random.PRNGKey,
 learn_opponent: bool = False,
@@ -682,9 +687,7 @@ if __name__ == "__main__":
 jax.tree.map(lambda *x: jnp.concatenate(x), *x)
 for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
 ]
-next_main, next_done = [
-jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]
-]
+next_main = jnp.concatenate(sharded_next_main)
 # reorder storage of individual players
 # main first, opponent second
@@ -713,8 +716,8 @@ if __name__ == "__main__":
 return x
 shuffled_init_rstate1, shuffled_init_rstate2, \
-shuffled_next_value, shuffled_next_done = jax.tree.map(
-partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
+shuffled_next_value = jax.tree.map(
+partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
 shuffled_storage = jax.tree.map(
 partial(convert_data, num_steps=num_steps), storage)
 shuffled_mask = jnp.ones_like(shuffled_storage.mains)
@@ -734,13 +737,13 @@ if __name__ == "__main__":
 shuffled_init_rstate2,
 shuffled_storage.obs,
 shuffled_storage.dones,
+shuffled_storage.next_dones,
 shuffled_storage.mains,
 shuffled_storage.actions,
 shuffled_storage.logits,
 shuffled_storage.rewards,
 shuffled_mask,
 shuffled_next_value,
-shuffled_next_done,
 ),
 )
 return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
@@ -765,7 +768,7 @@ if __name__ == "__main__":
 single_device_update,
 axis_name="local_devices",
 devices=global_learner_decices,
-static_broadcasted_argnums=(8,),
+static_broadcasted_argnums=(7,),
 )
 params_queues = []
...
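Note: with this change both training scripts compute the importance sampling ratios and the approximate KL on the flat (num_steps * num_envs,) batch before anything is reshaped to the time axis; only the quantities fed to the return computation are viewed as a time series. Below is a minimal sketch of that masked k3-style KL estimator, written with plain jax.nn instead of distrax; the helper name masked_approx_kl and the toy shapes are illustrative only and not part of this commit.

import jax
import jax.numpy as jnp

def masked_approx_kl(new_logits, old_logits, actions, mask):
    # Log-probabilities of the taken actions under the new and old policies.
    new_logp = jnp.take_along_axis(
        jax.nn.log_softmax(new_logits), actions[:, None], axis=-1).squeeze(-1)
    old_logp = jnp.take_along_axis(
        jax.nn.log_softmax(old_logits), actions[:, None], axis=-1).squeeze(-1)
    # Importance ratios pi_new(a|s) / pi_old(a|s), the role played by
    # distrax.importance_sampling_ratios in the diff above.
    logratio = new_logp - old_logp
    ratios = jnp.exp(logratio)
    # k3 estimator of KL(old || new), averaged over valid (unmasked) steps only.
    n_valids = jnp.sum(mask)
    approx_kl = (((ratios - 1.0) - logratio) * mask).sum() / n_valids
    return ratios, approx_kl

# Toy usage: a flat batch of 8 transitions with 4 discrete actions.
old_logits = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
new_logits = old_logits + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (8, 4))
actions = jax.random.randint(jax.random.PRNGKey(2), (8,), 0, 4)
mask = jnp.ones((8,))
ratios, approx_kl = masked_approx_kl(new_logits, old_logits, actions, mask)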
@@ -74,12 +74,12 @@ class Args:
 """the maximum number of options"""
 n_history_actions: int = 32
 """the number of history actions to use"""
-greedy_reward: bool = True
+greedy_reward: bool = False
 """whether to use greedy reward (faster kill higher reward)"""
-total_timesteps: int = 5000000000
+total_timesteps: int = 50000000000
 """total timesteps of the experiments"""
-learning_rate: float = 1e-3
+learning_rate: float = 3e-4
 """the learning rate of the optimizer"""
 local_num_envs: int = 128
 """the number of parallel game environments"""
@@ -93,16 +93,16 @@ class Args:
 """Toggle learning rate annealing for policy and value networks"""
 gamma: float = 1.0
 """the discount factor gamma"""
-gae_lambda: float = 0.95
-"""the lambda for the general advantage estimation"""
-upgo: bool = False
-"""Toggle the use of UPGO for advantages"""
-num_minibatches: int = 8
+num_minibatches: int = 64
 """the number of mini-batches"""
 update_epochs: int = 2
 """the K epochs to update the policy"""
 norm_adv: bool = False
 """Toggles advantages normalization"""
+upgo: bool = True
+"""Toggle the use of UPGO for advantages"""
+gae_lambda: float = 0.95
+"""the lambda for the general advantage estimation"""
 clip_coef: float = 0.25
 """the surrogate clipping coefficient"""
 dual_clip_coef: Optional[float] = 3.0
@@ -113,7 +113,7 @@ class Args:
 """the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
 ent_coef: float = 0.01
 """coefficient of the entropy"""
-vf_coef: float = 0.5
+vf_coef: float = 1.0
 """coefficient of the value function"""
 max_grad_norm: float = 1.0
 """the maximum norm for the gradient clipping"""
@@ -140,9 +140,9 @@ class Args:
 eval_checkpoint: Optional[str] = None
 """the path to the model checkpoint to evaluate"""
-local_eval_episodes: int = 32
+local_eval_episodes: int = 128
 """the number of episodes to evaluate the model"""
-eval_interval: int = 50
+eval_interval: int = 100
 """the number of iterations to evaluate the model"""
 # runtime arguments to be filled in
@@ -203,6 +203,7 @@ def create_agent(args, multi_step=False):
 dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
 param_dtype=jnp.float32,
 lstm_channels=args.rnn_channels,
+switch=True,
 multi_step=multi_step,
 freeze_id=args.freeze_id,
 )
@@ -632,28 +633,30 @@ if __name__ == "__main__":
 num_envs = next_value.shape[0]
 num_steps = dones.shape[0] // num_envs
+def reshape_time_series(x):
+return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
 mask = mask * (1.0 - dones)
 n_valids = jnp.sum(mask)
-real_dones = dones | next_dones
-inputs = (rstate1, rstate2, obs, real_dones, switch)
+dones = dones | next_dones
+inputs = (rstate1, rstate2, obs, dones, switch)
 new_logits, new_values = get_logits_and_value(params, inputs)
-new_values_, rewards, next_dones, switch = jax.tree.map(
-lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
-(new_values, rewards, next_dones, switch),
-)
 ratios = distrax.importance_sampling_ratios(distrax.Categorical(
 new_logits), distrax.Categorical(logits), actions)
+logratio = jnp.log(ratios)
+approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
+new_values_, rewards, next_dones, switch = jax.tree.map(
+reshape_time_series, (new_values, rewards, next_dones, switch),
+)
 target_values, advantages = truncated_gae_2p0s(
 next_value, new_values_, rewards, next_dones, switch,
 args.gamma, args.gae_lambda, args.upgo)
 target_values, advantages = jax.tree.map(
 lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
-logratio = jnp.log(ratios)
-approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
 if args.norm_adv:
 advantages = masked_normalize(advantages, mask, eps=1e-8)
@@ -699,9 +702,7 @@ if __name__ == "__main__":
 jax.tree.map(lambda *x: jnp.concatenate(x), *x)
 for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
 ]
-next_main, = [
-jnp.concatenate(x) for x in [sharded_next_main]
-]
+next_main = jnp.concatenate(sharded_next_main)
 # reorder storage of individual players
 # main first, opponent second
@@ -722,9 +723,7 @@ if __name__ == "__main__":
 next_value = create_agent(args).apply(
 agent_state.params, next_inputs)[2].squeeze(-1)
-# TODO: check if this is correct
-sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
-next_value = jnp.where(next_main, -next_value, next_value)
+next_value = jnp.where(next_main, -next_value, next_value)
 def convert_data(x: jnp.ndarray, num_steps):
 if args.update_epochs > 1:
...
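Note: both loss functions also share the new reshape_time_series round trip: rollout tensors arrive flattened as (num_steps * num_envs, ...), are viewed as (num_steps, num_envs, ...) only for the return computation (vtrace_2p0s in the first script, truncated_gae_2p0s here), and the resulting targets and advantages are flattened back for the minibatch loss. A minimal sketch of that round trip, using a hypothetical one-step bootstrapped return in place of the actual return functions (which are outside this diff) and ignoring the two-player mains/switch handling:

import jax
import jax.numpy as jnp

num_steps, num_envs = 4, 3

def reshape_time_series(x):
    # (num_steps * num_envs, ...) -> (num_steps, num_envs, ...); assumes the flat
    # axis is ordered time-major, matching how the rollout storage is stacked.
    return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])

# Flat rollout tensors with toy values.
rewards = jnp.arange(num_steps * num_envs, dtype=jnp.float32)
values = jnp.zeros(num_steps * num_envs)
next_dones = jnp.zeros(num_steps * num_envs)
next_value = jnp.zeros(num_envs)  # bootstrap value, one per environment

rewards_, values_, next_dones_ = map(reshape_time_series, (rewards, values, next_dones))

# Stand-in return: a one-step bootstrapped target scanned backwards over time
# (gamma = 1.0, as in the Args defaults), cut off where an episode ended.
targets = []
boot = next_value
for t in reversed(range(num_steps)):
    boot = rewards_[t] + (1.0 - next_dones_[t]) * boot
    targets.append(boot)
target_values = jnp.stack(targets[::-1])  # (num_steps, num_envs)
advantages = target_values - values_

# Flatten back so the loss keeps indexing the flat minibatch, as in the diff.
target_values, advantages = jax.tree.map(
    lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))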