Commit 34f86ae4 authored by sbl1996@126.com

Unify Impala and PPO

parent 9d8d4386
......@@ -165,7 +165,10 @@ if __name__ == "__main__":
else:
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
params1 = jax.device_put(params1)
params2 = jax.device_put(params2)
@jax.jit
def get_probs(params, rstate, obs, done):
agent = create_agent(args)
......
......@@ -17,6 +17,7 @@ import jax.numpy as jnp
import numpy as np
import optax
import rlax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
......@@ -28,6 +29,7 @@ from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_norm
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
......@@ -40,7 +42,9 @@ class Args:
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 100
"""the frequency of saving the model"""
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
......@@ -78,8 +82,6 @@ class Args:
"""the discount factor gamma"""
num_minibatches: int = 4
"""the number of mini-batches"""
gradient_accumulation_steps: int = 1
"""the number of gradient accumulation steps before performing an optimization step"""
c_clip_min: float = 0.001
"""the minimum value of the importance sampling clipping"""
c_clip_max: float = 1.007
......@@ -88,8 +90,6 @@ class Args:
"""the minimum value of the importance sampling clipping"""
rho_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping"""
upgo: bool = False
"""whether to use UPGO for policy update"""
ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25
......@@ -127,7 +127,6 @@ class Args:
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
num_updates: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
......@@ -218,34 +217,28 @@ def rollout(
avg_win_rates = deque(maxlen=1000)
@jax.jit
def apply_fn(
params: flax.core.FrozenDict,
next_obs,
):
logits, value, _valid = create_agent(args).apply(params, next_obs)
return logits, value
def get_logits(
params: flax.core.FrozenDict, inputs):
logits, value, _valid = create_agent(args).apply(params, inputs)[:2]
return logits
def get_action(
params: flax.core.FrozenDict,
next_obs,
):
return apply_fn(params, next_obs)[0].argmax(axis=1)
params: flax.core.FrozenDict, inputs):
return get_logits(params, inputs).argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs,
key: jax.random.PRNGKey,
):
next_obs, key: jax.random.PRNGKey):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = apply_fn(params, next_obs)[0]
logits = get_logits(params, next_obs)
# sample action: Gumbel-max trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
return next_obs, action, logits, key
# put data in the last index
envs.async_reset()
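Note on the sampling above: adding independent Gumbel noise to the logits and taking the argmax (the Gumbel-max trick) draws an exact sample from the categorical distribution defined by those logits. The `categorical_sample` helper that the PPO script below imports from `ygoai.rl.jax.utils` presumably factors out these same few lines; a minimal sketch under that assumption:

import jax
import jax.numpy as jnp

def categorical_sample(logits, key):
    # Gumbel-max trick: argmax(logits + Gumbel noise) is an exact sample
    # from the categorical distribution with the given logits.
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=logits.shape)
    action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
    return action, key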
......@@ -253,13 +246,13 @@ def rollout(
rollout_time = deque(maxlen=10)
actor_policy_version = 0
storage = []
ai_player1 = np.concatenate([
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
np.random.shuffle(main_player)
next_to_play = None
learn = np.ones(args.local_num_envs, dtype=np.bool_)
main = np.ones(args.local_num_envs, dtype=np.bool_)
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
......@@ -274,8 +267,7 @@ def rollout(
inference_time = 0
env_time = 0
num_steps_with_bootstrap = (
args.num_steps + int(len(storage) == 0)
) # num_steps + 1 to get the states for value bootstrapping.
args.num_steps + int(len(storage) == 0))
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
......@@ -295,11 +287,11 @@ def rollout(
_start = time.time()
next_obs, next_reward, next_done, info = envs.recv()
next_reward = np.where(learn, next_reward, -next_reward)
next_reward = np.where(main, next_reward, -next_reward)
env_time += time.time() - _start
to_play = next_to_play
next_to_play = info["to_play"]
learn = next_to_play == ai_player1
main = next_to_play == main_player
inference_time_start = time.time()
next_obs, action, logits, key = sample_action(params, next_obs, key)
......@@ -312,17 +304,17 @@ def rollout(
Transition(
obs=next_obs,
dones=next_done,
mains=main,
rewards=next_reward,
actions=action,
logitss=logits,
rewards=next_reward,
learns=learn,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
pl = 1 if to_play[idx] == ai_player1[idx] else -1
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
......@@ -488,7 +480,7 @@ if __name__ == "__main__":
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=args.gradient_accumulation_steps,
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
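For reference, `every_k_schedule` is the gradient-accumulation knob of `optax.MultiSteps`: the wrapped optimizer applies a real update only every k micro-batch gradients, so hard-coding 1 here (together with dropping the `gradient_accumulation_steps` argument) simply disables accumulation. A minimal sketch of the wrapper, with placeholder hyperparameters rather than this script's actual ones:

import optax

# Accumulate gradients for k steps, then apply one optimizer update;
# every_k_schedule=1 means an update on every step (no accumulation).
tx = optax.MultiSteps(
    optax.chain(
        optax.clip_by_global_norm(0.5),              # placeholder max_grad_norm
        optax.adam(learning_rate=2.5e-4, eps=1e-5),  # placeholder learning rate
    ),
    every_k_schedule=1,
)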
......@@ -505,13 +497,15 @@ if __name__ == "__main__":
params: flax.core.FrozenDict,
obs: np.ndarray,
):
logits, value, valid = create_agent(args).apply(params, obs)
return logits, value.squeeze(-1), valid
logits, value = create_agent(args).apply(params, obs)
return logits, value.squeeze(-1)
def impala_loss(params, obs, actions, logitss, rewards, dones, learns):
def impala_loss(
params, obs, actions, logitss, rewards, dones, learns):
# (num_steps + 1, local_num_envs // n_mb)
num_steps = actions.shape[0] - 1
discounts = (1.0 - dones) * args.gamma
policy_logits, newvalue, valid = jax.vmap(
policy_logits, newvalue = jax.vmap(
get_logits_and_value, in_axes=(None, 0))(params, obs)
newvalue = jnp.where(learns, newvalue, -newvalue)
......@@ -527,19 +521,14 @@ if __name__ == "__main__":
discounts = discounts[1:]
mask = mask[:-1]
rhos = rlax.categorical_importance_sampling_ratios(
policy_logits, logitss, actions)
rhos = distrax.importance_sampling_ratios(distrax.Categorical(
policy_logits), distrax.Categorical(logitss), actions)
vtrace_fn = partial(
vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
vtrace_returns = jax.vmap(
vtrace_fn, in_axes=1, out_axes=1)(
v_tm1, v_t, rewards, discounts, rhos)
jax.debug.print("R {}", jnp.where(dones[1:-1, :2], rewards[:-1, :2], 0).T)
jax.debug.print("E {}", jnp.where(dones[1:-1, :2], vtrace_returns.errors[:-1, :2] * 100, vtrace_returns.errors[:-1, :2]).T)
jax.debug.print("V {}", v_tm1[:-1, :2].T)
T = v_tm1.shape[0]
if args.upgo:
advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
rewards, v_t, discounts) - v_tm1
......@@ -548,13 +537,13 @@ if __name__ == "__main__":
if args.ppo_clip:
pg_loss = jax.vmap(
partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
rhos, advs, mask) * T
rhos, advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
else:
pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
pg_loss = jax.vmap(
rlax.policy_gradient_loss, in_axes=1)(
policy_logits, actions, pg_advs, mask) * T
policy_logits, actions, pg_advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
......
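The change from `rlax.categorical_importance_sampling_ratios` to `distrax.importance_sampling_ratios` should be behaviour-preserving: both compute the per-step importance ratio rho_t = pi(a_t|s_t) / mu(a_t|s_t) between the current (target) policy and the behaviour policy that produced the rollout, which V-trace then clips with the `c_clip_*` and `rho_clip_*` bounds. A small sketch of the assumed equivalence, using made-up logits:

import distrax
import jax.numpy as jnp
import rlax

policy_logits = jnp.array([[1.0, 0.5, -0.5]])     # current (target) policy
behaviour_logits = jnp.array([[0.8, 0.6, -0.4]])  # policy that collected the data
actions = jnp.array([1])

# Both are exp(log pi(a|s) - log mu(a|s)).
rhos_rlax = rlax.categorical_importance_sampling_ratios(
    policy_logits, behaviour_logits, actions)
rhos_distrax = distrax.importance_sampling_ratios(
    distrax.Categorical(policy_logits), distrax.Categorical(behaviour_logits), actions)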
......@@ -23,7 +23,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
......@@ -255,11 +255,7 @@ def rollout(
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
action, key = categorical_sample(logits, key)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
......@@ -329,7 +325,6 @@ def rollout(
inference_time += time.time() - inference_time_start
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
......@@ -338,11 +333,11 @@ def rollout(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=main,
actions=action,
logprobs=logprob,
rewards=next_reward,
mains=main,
probs=probs,
rewards=next_reward,
)
)
......@@ -359,8 +354,7 @@ def rollout(
t.dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_reward = info['r'][idx] * pl
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
......@@ -387,16 +381,14 @@ def rollout(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
learn_opponent = False
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
learn_opponent,
)
rollout_queue.put(payload)
......@@ -589,7 +581,6 @@ if __name__ == "__main__":
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
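The surrounding hunk is the standard PPO clipped surrogate plus a squared-error value loss, both computed elementwise and then averaged over the `valid` mask. A minimal sketch of the quantities involved, with illustrative names (`ratio`, `clip_coef`) rather than the script's exact locals:

import jax.numpy as jnp

def ppo_losses(ratio, advantages, newvalue, target_values, clip_coef):
    # Clipped policy surrogate: take the elementwise maximum of the unclipped
    # and clipped loss terms (i.e. the minimum of the two objectives).
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * jnp.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
    pg_loss = jnp.maximum(pg_loss1, pg_loss2)
    # Squared-error value loss against the bootstrapped return targets.
    v_loss = 0.5 * (newvalue - target_values) ** 2
    return pg_loss, v_loss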
......@@ -600,10 +591,9 @@ if __name__ == "__main__":
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_rstate: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
......@@ -620,9 +610,9 @@ if __name__ == "__main__":
return x
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs, next_rstate, init_rstate1, init_rstate2 = [
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_obs, sharded_next_rstate, sharded_init_rstate1, sharded_init_rstate2]
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_done, next_main = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]
......@@ -680,7 +670,7 @@ if __name__ == "__main__":
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
......@@ -745,7 +735,7 @@ if __name__ == "__main__":
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(9,),
static_broadcasted_argnums=(8,),
)
params_queues = []
......@@ -786,11 +776,9 @@ if __name__ == "__main__":
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
*sharded_data,
avg_params_queue_get_time,
device_thread_id,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
......
......@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
......@@ -122,6 +122,8 @@ class Args:
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
......@@ -198,12 +200,16 @@ def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
params_queue,
eval_queue,
writer,
learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
......@@ -217,7 +223,7 @@ def rollout(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
args.local_eval_episodes // 4, mode=eval_mode)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
......@@ -244,11 +250,23 @@ def rollout(
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done)
main = jnp.array(main)
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
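Both `get_action_battle` and `sample_action` lean on the same pattern: two pytrees of LSTM states are merged per environment with `jax.tree.map` and a broadcasted boolean mask, so each env reads (and afterwards writes back) the state of whichever player is currently acting. A small standalone illustration of that pattern with toy shapes:

import jax
import jax.numpy as jnp

num_envs, hidden = 4, 8
rstate_a = (jnp.ones((num_envs, hidden)), jnp.ones((num_envs, hidden)))
rstate_b = (jnp.zeros((num_envs, hidden)), jnp.zeros((num_envs, hidden)))
main = jnp.array([True, False, True, False])  # which envs the first player acts in

# Rows where `main` is True come from rstate_a, the others from rstate_b.
merged = jax.tree.map(
    lambda a, b: jnp.where(main[:, None], a, b), rstate_a, rstate_b)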
......@@ -257,7 +275,7 @@ def rollout(
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
action, key = categorical_sample(logits, key)
return next_obs, rstate1, rstate2, action, logits, key
return next_obs, done, main, rstate1, rstate2, action, logits, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
......@@ -314,7 +332,8 @@ def rollout(
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, next_rstate1, next_rstate2, action, logits, key = sample_action(
cached_next_obs, cached_next_done, cached_main, \
next_rstate1, next_rstate2, action, logits, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
......@@ -329,7 +348,7 @@ def rollout(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=main,
mains=cached_main,
actions=action,
logits=logits,
rewards=next_reward,
......@@ -412,19 +431,28 @@ def rollout(
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
metric_name = "eval_return"
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
metric_name = "eval_win_rate"
if device_thread_id != 0:
stats_queue.put(eval_return)
eval_queue.put(eval_stat)
else:
eval_stats = []
eval_stats.append(eval_return)
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time
......@@ -524,15 +552,24 @@ if __name__ == "__main__":
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
......@@ -711,7 +748,7 @@ if __name__ == "__main__":
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
eval_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
......@@ -721,7 +758,9 @@ if __name__ == "__main__":
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
......@@ -729,12 +768,13 @@ if __name__ == "__main__":
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
eval_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
......
import numpy as np
def evaluate(envs, act_fn, params, rnn_state=None):
num_episodes = envs.num_envs
def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
episode_lengths = []
episode_rewards = []
eval_win_rates = []
win_rates = []
obs = envs.reset()[0]
collected = np.zeros((num_episodes,), dtype=np.bool_)
while True:
if rnn_state is None:
actions = act_fn(params, obs)
actions = predict_fn(obs)
else:
rnn_state, actions = act_fn(params, (rnn_state, obs))
rnn_state, actions = predict_fn((rnn_state, obs))
actions = np.array(actions)
obs, rewards, dones, info = envs.step(actions)
......@@ -27,11 +26,54 @@ def evaluate(envs, act_fn, params, rnn_state=None):
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
eval_win_rates.append(win)
win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(eval_win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
eval_win_rate = np.mean(win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate
def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
num_envs = envs.num_envs
episode_rewards = []
episode_lengths = []
win_rates = []
obs, infos = envs.reset()
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
rstate1 = rstate2 = init_rnn_state
while True:
main = next_to_play == main_player
rstate1, rstate2, actions = predict_fn(rstate1, rstate2, obs, main, dones)
actions = np.array(actions)
obs, rewards, dones, infos = envs.step(actions)
next_to_play = infos['to_play']
for idx, d in enumerate(dones):
if not d:
continue
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
win_rates.append(win)
if len(episode_lengths) >= num_episodes:
break
eval_return = np.mean(episode_rewards[:num_episodes])
eval_ep_len = np.mean(episode_lengths[:num_episodes])
eval_win_rate = np.mean(win_rates[:num_episodes])
return eval_return, eval_ep_len, eval_win_rate