Commit 0e9969c5 authored by sbl1996@126.com

Unify PPO and Impala to Cleanba

parent 907b51bc
@@ -73,12 +73,12 @@ class Args:
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = True
greedy_reward: bool = False
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 5000000000
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
@@ -92,12 +92,12 @@ class Args:
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
num_minibatches: int = 64
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
upgo: bool = True
"""Toggle the use of UPGO for advantages"""
c_clip_min: float = 0.001
"""the minimum value of the importance sampling clipping"""
c_clip_max: float = 1.007
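For reference, `c_clip_min`/`c_clip_max` above (together with the `rho_clip_*` pair used alongside them in the V-trace call later in this diff) are the c-bar and rho-bar clipping thresholds of V-trace. A minimal single-agent sketch of how they enter the V-trace targets, assuming plain interval clipping; the repo's `vtrace_2p0s` additionally handles episode resets and the two-player turn structure:

```python
import jax
import jax.numpy as jnp

def vtrace_1p(next_value, ratios, values, rewards, gamma,
              rho_clip_min, rho_clip_max, c_clip_min, c_clip_max):
    """Single-agent V-trace targets (Espeholt et al., 2018) via a reverse scan.

    ratios, values, rewards: (T, B); next_value: (B,).
    Illustrative only: ignores episode boundaries and the two-player
    handling implemented by the repo's vtrace_2p0s.
    """
    rhos = jnp.clip(ratios, rho_clip_min, rho_clip_max)
    cs = jnp.clip(ratios, c_clip_min, c_clip_max)

    def step(carry, x):
        next_value, next_vs = carry
        rho, c, value, reward = x
        delta = rho * (reward + gamma * next_value - value)
        vs = value + delta + gamma * c * (next_vs - next_value)
        return (value, vs), vs

    # Scan backwards in time; the initial carry is (V(s_T), vs_T) = (next_value, next_value).
    _, vs = jax.lax.scan(
        step, (next_value, next_value), (rhos, cs, values, rewards),
        reverse=True)
    vs_next = jnp.concatenate([vs[1:], next_value[None]], axis=0)
    advantages = rhos * (rewards + gamma * vs_next - values)
    return vs, advantages
```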
@@ -141,9 +141,9 @@ class Args:
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
local_eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
eval_interval: int = 100
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
@@ -193,6 +193,7 @@ class Transition(NamedTuple):
logits: list
rewards: list
mains: list
next_dones: list
def create_agent(args, multi_step=False):
@@ -203,6 +204,7 @@ def create_agent(args, multi_step=False):
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
switch=False,
multi_step=multi_step,
freeze_id=args.freeze_id,
)
@@ -373,6 +375,7 @@ def rollout(
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
@@ -405,7 +408,7 @@ def rollout(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
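The `jax.device_put_sharded` call above splits each rollout array along the environment axis and places one shard on each learner device; `next_done` is dropped from that payload since the learner now reads per-step `next_dones` from storage. A minimal sketch of the sharding pattern (array contents are illustrative):

```python
import jax
import numpy as np

learner_devices = jax.local_devices()
batch = np.arange(len(learner_devices) * 4, dtype=np.float32)

# One shard per learner device, split along the leading (environment) axis.
shards = np.split(batch, len(learner_devices))
sharded_batch = jax.device_put_sharded(shards, devices=learner_devices)
```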
@@ -616,33 +619,36 @@ if __name__ == "__main__":
return logits, value.squeeze(-1)
def loss_fn(
params, rstate1, rstate2, obs, dones, mains,
actions, logits, rewards, mask, next_value, next_done):
params, rstate1, rstate2, obs, dones, next_dones,
mains, actions, logits, rewards, mask, next_value):
# (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
def reshape_time_series(x):
return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
inputs = (rstate1, rstate2, obs, dones, mains)
new_logits, new_values = get_logits_and_value(params, inputs)
new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_logits, new_values, logits, actions, rewards, dones, mains, mask),
)
next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
ratios_, new_values_, rewards, next_dones, mains = jax.tree.map(
reshape_time_series, (ratios, new_values, rewards, next_dones, mains),
)
# TODO: TD(lambda) for multi-step
target_values, advantages = vtrace_2p0s(
next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
next_value, ratios_, new_values_, rewards, next_dones, mains, args.gamma,
args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
if args.ppo_clip:
pg_loss = clipped_surrogate_pg_loss(
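A self-contained sketch of the ratio and approximate-KL computation that `loss_fn` now performs on the flat `(T * B, ...)` tensors, followed by the reshape to `(T, B, ...)` that feeds V-trace (shapes and values below are illustrative):

```python
import jax
import jax.numpy as jnp
import distrax

num_steps, num_envs, num_actions = 4, 2, 5
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
new_logits = jax.random.normal(k1, (num_steps * num_envs, num_actions))
old_logits = jax.random.normal(k2, (num_steps * num_envs, num_actions))
actions = jnp.zeros((num_steps * num_envs,), dtype=jnp.int32)
mask = jnp.ones((num_steps * num_envs,))

# pi_new(a|s) / pi_old(a|s) for the actions taken by the actors
ratios = distrax.importance_sampling_ratios(
    distrax.Categorical(new_logits), distrax.Categorical(old_logits), actions)

# (r - 1) - log r estimator of the KL between behaviour and current policy,
# averaged over unmasked steps
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / jnp.sum(mask)

# flat (T * B,) tensors are reshaped to (T, B) before the V-trace scan
ratios_ = jnp.reshape(ratios, (num_steps, num_envs))
```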
@@ -671,7 +677,6 @@ if __name__ == "__main__":
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
@@ -682,9 +687,7 @@ if __name__ == "__main__":
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main, next_done = [
jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]
]
next_main = jnp.concatenate(sharded_next_main)
# reorder storage of individual players
# main first, opponent second
@@ -713,8 +716,8 @@ if __name__ == "__main__":
return x
shuffled_init_rstate1, shuffled_init_rstate2, \
shuffled_next_value, shuffled_next_done = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage = jax.tree.map(
partial(convert_data, num_steps=num_steps), storage)
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
@@ -734,13 +737,13 @@ if __name__ == "__main__":
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_storage.mains,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
shuffled_next_done,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
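The `(agent_state, key)` carry and per-minibatch loss outputs returned above suggest the minibatch loop is a `jax.lax.scan` over shuffled minibatches. A toy sketch of that pattern with a linear model and `optax`; everything here is illustrative rather than the repo's agent, and the real loop also threads the RNG key through the carry:

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((3,))}
opt = optax.adam(3e-4)
opt_state = opt.init(params)

def loss_fn(params, x, y):
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

def update_minibatch(carry, minibatch):
    # Scan body: one gradient step per minibatch, carrying the train state.
    params, opt_state = carry
    x, y = minibatch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), loss

# 8 minibatches of 16 samples each, stacked along the leading (scan) axis.
xs = (jnp.ones((8, 16, 3)), jnp.zeros((8, 16)))
(params, opt_state), losses = jax.lax.scan(
    update_minibatch, (params, opt_state), xs)
```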
@@ -765,7 +768,7 @@ if __name__ == "__main__":
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(8,),
static_broadcasted_argnums=(7,),
)
params_queues = []
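`static_broadcasted_argnums` moves from 8 to 7 because removing `sharded_next_done` from `single_device_update`'s positional arguments shifts the trailing static flag (presumably `learn_opponent`) one position earlier. A toy sketch of declaring a static boolean argument under `jax.pmap` (names are illustrative):

```python
import jax
import jax.numpy as jnp

def update(params, batch, learn_opponent: bool = False):
    # A Python bool used for control flow must be a static (compile-time)
    # argument when the function is pmapped; its positional index is what
    # static_broadcasted_argnums refers to.
    return params + batch if learn_opponent else params - batch

n_dev = jax.local_device_count()
pmapped_update = jax.pmap(
    update, axis_name="local_devices", static_broadcasted_argnums=(2,))

params = jnp.zeros((n_dev, 3))
batch = jnp.ones((n_dev, 3))
out = pmapped_update(params, batch, True)
```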
......
@@ -74,12 +74,12 @@ class Args:
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = True
greedy_reward: bool = False
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps: int = 5000000000
total_timesteps: int = 50000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
@@ -93,16 +93,16 @@ class Args:
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
num_minibatches: int = 64
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
upgo: bool = True
"""Toggle the use of UPGO for advantages"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
dual_clip_coef: Optional[float] = 3.0
@@ -113,7 +113,7 @@ class Args:
"""the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
vf_coef: float = 1.0
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
@@ -140,9 +140,9 @@ class Args:
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
local_eval_episodes: int = 128
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
eval_interval: int = 100
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
@@ -203,6 +203,7 @@ def create_agent(args, multi_step=False):
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
switch=True,
multi_step=multi_step,
freeze_id=args.freeze_id,
)
@@ -632,28 +633,30 @@ if __name__ == "__main__":
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
def reshape_time_series(x):
return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
mask = mask * (1.0 - dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
inputs = (rstate1, rstate2, obs, dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
new_values_, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_values, rewards, next_dones, switch),
)
ratios = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
new_values_, rewards, next_dones, switch = jax.tree.map(
reshape_time_series, (new_values, rewards, next_dones, switch),
)
target_values, advantages = truncated_gae_2p0s(
next_value, new_values_, rewards, next_dones, switch,
args.gamma, args.gae_lambda, args.upgo)
target_values, advantages = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
logratio = jnp.log(ratios)
approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
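`truncated_gae_2p0s` extends the usual truncated GAE recursion with UPGO and the two-player `switch` handling. For reference, a minimal single-agent version of that recursion (illustrative only, not the repo's implementation):

```python
import jax
import jax.numpy as jnp

def gae_1p(next_value, values, rewards, next_dones, gamma=1.0, gae_lambda=0.95):
    """Standard single-agent truncated GAE via a reverse scan.

    values, rewards, next_dones: (T, B); next_value: (B,).
    next_dones[t] flags termination at the end of step t.
    """
    def step(carry, x):
        next_value, next_adv = carry
        value, reward, next_done = x
        nonterminal = 1.0 - next_done
        delta = reward + gamma * next_value * nonterminal - value
        adv = delta + gamma * gae_lambda * nonterminal * next_adv
        return (value, adv), adv

    _, advantages = jax.lax.scan(
        step, (next_value, jnp.zeros_like(next_value)),
        (values, rewards, next_dones), reverse=True)
    target_values = advantages + values
    return target_values, advantages
```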
@@ -699,9 +702,7 @@ if __name__ == "__main__":
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main, = [
jnp.concatenate(x) for x in [sharded_next_main]
]
next_main = jnp.concatenate(sharded_next_main)
# reorder storage of individual players
# main first, opponent second
@@ -722,9 +723,7 @@ if __name__ == "__main__":
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
next_value = jnp.where(next_main, -next_value, next_value)
def convert_data(x: jnp.ndarray, num_steps):
if args.update_epochs > 1:
......