Commit 34f86ae4 authored by sbl1996@126.com

Unify Impala and PPO

parent 9d8d4386
@@ -165,7 +165,10 @@ if __name__ == "__main__":
     else:
         with open(args.checkpoint2, "rb") as f:
             params2 = flax.serialization.from_bytes(params, f.read())
+    params1 = jax.device_put(params1)
+    params2 = jax.device_put(params2)

     @jax.jit
     def get_probs(params, rstate, obs, done):
         agent = create_agent(args)
...
@@ -17,6 +17,7 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 import rlax
+import distrax
 import tyro
 from flax.training.train_state import TrainState
 from rich.pretty import pprint
@@ -28,6 +29,7 @@ from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_norm
 from ygoai.rl.jax.eval import evaluate
 from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss

 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -40,7 +42,9 @@ class Args:
     log_frequency: int = 10
     """the logging frequency of the model performance (in terms of `updates`)"""
     save_interval: int = 100
-    """the frequency of saving the model"""
+    """the frequency of saving the model (in terms of `updates`)"""
+    checkpoint: Optional[str] = None
+    """the path to the model checkpoint to load"""

     # Algorithm specific arguments
     env_id: str = "YGOPro-v0"
@@ -78,8 +82,6 @@ class Args:
     """the discount factor gamma"""
     num_minibatches: int = 4
     """the number of mini-batches"""
-    gradient_accumulation_steps: int = 1
-    """the number of gradient accumulation steps before performing an optimization step"""
     c_clip_min: float = 0.001
     """the minimum value of the importance sampling clipping"""
     c_clip_max: float = 1.007
@@ -88,8 +90,6 @@ class Args:
     """the minimum value of the importance sampling clipping"""
     rho_clip_max: float = 1.007
     """the maximum value of the importance sampling clipping"""
-    upgo: bool = False
-    """whether to use UPGO for policy update"""
     ppo_clip: bool = True
     """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
     clip_coef: float = 0.25
@@ -127,7 +127,6 @@ class Args:
     # runtime arguments to be filled in
     local_batch_size: int = 0
     local_minibatch_size: int = 0
-    num_updates: int = 0
     world_size: int = 0
     local_rank: int = 0
     num_envs: int = 0
@@ -218,34 +217,28 @@ def rollout(
     avg_win_rates = deque(maxlen=1000)

     @jax.jit
-    def apply_fn(
-        params: flax.core.FrozenDict,
-        next_obs,
-    ):
-        logits, value, _valid = create_agent(args).apply(params, next_obs)
-        return logits, value
+    def get_logits(
+        params: flax.core.FrozenDict, inputs):
+        logits, value = create_agent(args).apply(params, inputs)[:2]
+        return logits

     def get_action(
-        params: flax.core.FrozenDict,
-        next_obs,
-    ):
-        return apply_fn(params, next_obs)[0].argmax(axis=1)
+        params: flax.core.FrozenDict, inputs):
+        return get_logits(params, inputs).argmax(axis=1)

     @jax.jit
     def sample_action(
         params: flax.core.FrozenDict,
-        next_obs,
-        key: jax.random.PRNGKey,
-    ):
+        next_obs, key: jax.random.PRNGKey):
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
-        logits = apply_fn(params, next_obs)[0]
+        logits = get_logits(params, next_obs)
         # sample action: Gumbel-softmax trick
         # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
         key, subkey = jax.random.split(key)
         u = jax.random.uniform(subkey, shape=logits.shape)
         action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
         return next_obs, action, logits, key

     # put data in the last index
     envs.async_reset()
@@ -253,13 +246,13 @@ def rollout(
     rollout_time = deque(maxlen=10)
     actor_policy_version = 0
     storage = []
-    ai_player1 = np.concatenate([
+    main_player = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
-    np.random.shuffle(ai_player1)
+    np.random.shuffle(main_player)
     next_to_play = None
-    learn = np.ones(args.local_num_envs, dtype=np.bool_)
+    main = np.ones(args.local_num_envs, dtype=np.bool_)

     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:
@@ -274,8 +267,7 @@ def rollout(
         inference_time = 0
         env_time = 0
         num_steps_with_bootstrap = (
-            args.num_steps + int(len(storage) == 0)
-        )  # num_steps + 1 to get the states for value bootstrapping.
+            args.num_steps + int(len(storage) == 0))
         params_queue_get_time_start = time.time()
         if args.concurrency:
             if update != 2:
@@ -295,11 +287,11 @@ def rollout(
             _start = time.time()
             next_obs, next_reward, next_done, info = envs.recv()
-            next_reward = np.where(learn, next_reward, -next_reward)
+            next_reward = np.where(main, next_reward, -next_reward)
             env_time += time.time() - _start

             to_play = next_to_play
             next_to_play = info["to_play"]
-            learn = next_to_play == ai_player1
+            main = next_to_play == main_player

             inference_time_start = time.time()
             next_obs, action, logits, key = sample_action(params, next_obs, key)
@@ -312,17 +304,17 @@ def rollout(
                 Transition(
                     obs=next_obs,
                     dones=next_done,
+                    mains=main,
+                    rewards=next_reward,
                     actions=action,
                     logitss=logits,
-                    rewards=next_reward,
-                    learns=learn,
                 )
             )

             for idx, d in enumerate(next_done):
                 if not d:
                     continue
-                pl = 1 if to_play[idx] == ai_player1[idx] else -1
+                pl = 1 if to_play[idx] == main_player[idx] else -1
                 episode_reward = info['r'][idx] * pl
                 win = 1 if episode_reward > 0 else 0
                 avg_ep_returns.append(episode_reward)
@@ -488,7 +480,7 @@ if __name__ == "__main__":
                 learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
             ),
         ),
-        every_k_schedule=args.gradient_accumulation_steps,
+        every_k_schedule=1,
     )
     agent_state = TrainState.create(
         apply_fn=None,
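With gradient_accumulation_steps gone, every_k_schedule=1 makes the optax.MultiSteps wrapper apply the inner update on every call, i.e. no accumulation. A minimal standalone sketch of that behaviour (toy parameters and a plain Adam inner optimizer, not the training script's actual optimizer chain):

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=1)
opt_state = tx.init(params)

grads = {"w": jnp.ones(3)}
# with every_k_schedule=1 the wrapped update is applied immediately on each step
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)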
@@ -505,13 +497,15 @@ if __name__ == "__main__":
         params: flax.core.FrozenDict,
         obs: np.ndarray,
     ):
-        logits, value, valid = create_agent(args).apply(params, obs)
-        return logits, value.squeeze(-1), valid
+        logits, value = create_agent(args).apply(params, obs)
+        return logits, value.squeeze(-1)

-    def impala_loss(params, obs, actions, logitss, rewards, dones, learns):
+    def impala_loss(
+        params, obs, actions, logitss, rewards, dones, learns):
         # (num_steps + 1, local_num_envs // n_mb))
+        num_steps = actions.shape[0] - 1
         discounts = (1.0 - dones) * args.gamma
-        policy_logits, newvalue, valid = jax.vmap(
+        policy_logits, newvalue = jax.vmap(
             get_logits_and_value, in_axes=(None, 0))(params, obs)
         newvalue = jnp.where(learns, newvalue, -newvalue)
@@ -527,19 +521,14 @@ if __name__ == "__main__":
         discounts = discounts[1:]
         mask = mask[:-1]

-        rhos = rlax.categorical_importance_sampling_ratios(
-            policy_logits, logitss, actions)
+        rhos = distrax.importance_sampling_ratios(distrax.Categorical(
+            policy_logits), distrax.Categorical(logitss), actions)

         vtrace_fn = partial(
             vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
         vtrace_returns = jax.vmap(
             vtrace_fn, in_axes=1, out_axes=1)(
             v_tm1, v_t, rewards, discounts, rhos)
-        jax.debug.print("R {}", jnp.where(dones[1:-1, :2], rewards[:-1, :2], 0).T)
-        jax.debug.print("E {}", jnp.where(dones[1:-1, :2], vtrace_returns.errors[:-1, :2] * 100, vtrace_returns.errors[:-1, :2]).T)
-        jax.debug.print("V {}", v_tm1[:-1, :2].T)
-        T = v_tm1.shape[0]

         if args.upgo:
             advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
                 rewards, v_t, discounts) - v_tm1
@@ -548,13 +537,13 @@ if __name__ == "__main__":
         if args.ppo_clip:
             pg_loss = jax.vmap(
                 partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
-                rhos, advs, mask) * T
+                rhos, advs, mask) * num_steps
             pg_loss = jnp.sum(pg_loss)
         else:
             pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
             pg_loss = jax.vmap(
                 rlax.policy_gradient_loss, in_axes=1)(
-                policy_logits, actions, pg_advs, mask) * T
+                policy_logits, actions, pg_advs, mask) * num_steps
             pg_loss = jnp.sum(pg_loss)

         baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
...
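For reference, a small standalone check (toy shapes, not the training code) of the distrax call introduced above: importance_sampling_ratios over two Categorical distributions should give exp(log pi(a) - log mu(a)), the same quantity the replaced rlax.categorical_importance_sampling_ratios computed from raw logits.

import jax
import jax.numpy as jnp
import distrax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
policy_logits = jax.random.normal(k1, (5, 4))    # target (current) policy
behavior_logits = jax.random.normal(k2, (5, 4))  # behavior policy that produced the actions
actions = jnp.array([0, 1, 2, 3, 0])

rhos = distrax.importance_sampling_ratios(
    distrax.Categorical(policy_logits), distrax.Categorical(behavior_logits), actions)

# the same ratios computed directly from the logits
logp_pi = jax.nn.log_softmax(policy_logits)[jnp.arange(5), actions]
logp_mu = jax.nn.log_softmax(behavior_logits)[jnp.arange(5), actions]
assert jnp.allclose(rhos, jnp.exp(logp_pi - logp_mu), atol=1e-5)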
This diff is collapsed.
@@ -23,7 +23,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
+from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate
 from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
@@ -255,11 +255,7 @@ def rollout(
             rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
             rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)

-            # sample action: Gumbel-softmax trick
-            # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
-            key, subkey = jax.random.split(key)
-            u = jax.random.uniform(subkey, shape=logits.shape)
-            action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
+            action, key = categorical_sample(logits, key)

             logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
             logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
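The body of categorical_sample is not shown in this diff; a plausible sketch of the helper (an assumption about ygoai.rl.jax.utils, mirroring the Gumbel-max sampling it replaces here):

import jax
import jax.numpy as jnp

def categorical_sample(logits, key):
    # argmax over Gumbel-perturbed logits is an exact sample from softmax(logits)
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=logits.shape)
    action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
    return action, key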
@@ -329,7 +325,6 @@ def rollout(
             inference_time += time.time() - inference_time_start

             _start = time.time()
-            to_play = next_to_play
             next_obs, next_reward, next_done, info = envs.step(cpu_action)
             next_to_play = info["to_play"]
             env_time += time.time() - _start
@@ -338,11 +333,11 @@ def rollout(
                 Transition(
                     obs=cached_next_obs,
                     dones=cached_next_done,
+                    mains=main,
                     actions=action,
                     logprobs=logprob,
-                    rewards=next_reward,
-                    mains=main,
                     probs=probs,
+                    rewards=next_reward,
                 )
             )
@@ -359,8 +354,7 @@ def rollout(
                         t.dones[idx] = True
                         t.rewards[idx] = -next_reward[idx]
                         break
-                pl = 1 if to_play[idx] == main_player[idx] else -1
-                episode_reward = info['r'][idx] * pl
+                episode_reward = info['r'][idx] * (1 if cur_main else -1)
                 win = 1 if episode_reward > 0 else 0
                 avg_ep_returns.append(episode_reward)
                 avg_win_rates.append(win)
@@ -387,16 +381,14 @@ def rollout(
                 lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
             sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
                 np.split(x, len(learner_devices)), devices=learner_devices),
-                (next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
+                (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
             learn_opponent = False
             payload = (
                 global_step,
-                actor_policy_version,
                 update,
                 sharded_storage,
                 *sharded_data,
                 np.mean(params_queue_get_time),
-                device_thread_id,
                 learn_opponent,
             )
             rollout_queue.put(payload)
@@ -589,7 +581,6 @@ if __name__ == "__main__":
         pg_loss = jnp.maximum(pg_loss1, pg_loss2)
         pg_loss = masked_mean(pg_loss, valid)

-        # Value loss
         v_loss = 0.5 * ((newvalue - target_values) ** 2)
         v_loss = masked_mean(v_loss, valid)
@@ -600,10 +591,9 @@ if __name__ == "__main__":
     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
-        sharded_next_obs: List,
-        sharded_next_rstate: List,
         sharded_init_rstate1: List,
         sharded_init_rstate2: List,
+        sharded_next_inputs: List,
         sharded_next_done: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
@@ -620,9 +610,9 @@ if __name__ == "__main__":
             return x

         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-        next_obs, next_rstate, init_rstate1, init_rstate2 = [
+        next_inputs, init_rstate1, init_rstate2 = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-            for x in [sharded_next_obs, sharded_next_rstate, sharded_init_rstate1, sharded_init_rstate2]
+            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
         ]
         next_done, next_main = [
             jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]
@@ -680,7 +670,7 @@ if __name__ == "__main__":
         values = values.reshape(storage.rewards.shape)

         next_value = create_agent(args).apply(
-            agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
+            agent_state.params, next_inputs)[2].squeeze(-1)
         # TODO: check if this is correct
         sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
         next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
@@ -745,7 +735,7 @@ if __name__ == "__main__":
         single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
-        static_broadcasted_argnums=(9,),
+        static_broadcasted_argnums=(8,),
     )

     params_queues = []
@@ -786,11 +776,9 @@ if __name__ == "__main__":
             for thread_id in range(args.num_actor_threads):
                 (
                     global_step,
-                    actor_policy_version,
                     update,
                     *sharded_data,
                     avg_params_queue_get_time,
-                    device_thread_id,
                     learn_opponent,
                 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                 sharded_data_list.append(sharded_data)
...
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
-from ygoai.rl.jax.eval import evaluate
+from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
@@ -122,6 +122,8 @@ class Args:
     thread_affinity: bool = False
     """whether to use thread affinity for the environment"""
+    eval_checkpoint: Optional[str] = None
+    """the path to the model checkpoint to evaluate"""
     local_eval_episodes: int = 32
     """the number of episodes to evaluate the model"""
     eval_interval: int = 50
@@ -198,12 +200,16 @@ def rollout(
     key: jax.random.PRNGKey,
     args: Args,
     rollout_queue,
-    params_queue: queue.Queue,
-    stats_queue,
+    params_queue,
+    eval_queue,
     writer,
     learner_devices,
     device_thread_id,
 ):
+    eval_mode = 'self' if args.eval_checkpoint else 'bot'
+    if eval_mode != 'bot':
+        eval_params = params_queue.get()
+
     envs = make_env(
         args,
         args.seed + jax.process_index() + device_thread_id,
@@ -217,7 +223,7 @@ def rollout(
         args,
         args.seed + jax.process_index() + device_thread_id,
         args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode='bot')
+        args.local_eval_episodes // 4, mode=eval_mode)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)
@@ -244,11 +250,23 @@ def rollout(
         rstate, logits = get_logits(params, inputs, done)
         return rstate, logits.argmax(axis=1)

+    @jax.jit
+    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        logits = jnp.where(main[:, None], logits1, logits2)
+        rstate1 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
+        rstate2 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        return rstate1, rstate2, logits.argmax(axis=1)
+
     @jax.jit
     def sample_action(
         params: flax.core.FrozenDict,
         next_obs, rstate1, rstate2, main, done, key):
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
+        done = jnp.array(done)
         main = jnp.array(main)
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
@@ -257,7 +275,7 @@ def rollout(
         rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
         action, key = categorical_sample(logits, key)

-        return next_obs, rstate1, rstate2, action, logits, key
+        return next_obs, done, main, rstate1, rstate2, action, logits, key

     # put data in the last index
     params_queue_get_time = deque(maxlen=10)
@@ -314,7 +332,8 @@ def rollout(
             main = next_to_play == main_player

             inference_time_start = time.time()
-            cached_next_obs, next_rstate1, next_rstate2, action, logits, key = sample_action(
+            cached_next_obs, cached_next_done, cached_main, \
+                next_rstate1, next_rstate2, action, logits, key = sample_action(
                 params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
             cpu_action = np.array(action)
@@ -329,7 +348,7 @@ def rollout(
                 Transition(
                     obs=cached_next_obs,
                     dones=cached_next_done,
-                    mains=main,
+                    mains=cached_main,
                     actions=action,
                     logits=logits,
                     rewards=next_reward,
@@ -412,19 +431,28 @@ def rollout(
         if args.eval_interval and update % args.eval_interval == 0:
             # Eval with rule-based policy
             _start = time.time()
-            eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
+            if eval_mode == 'bot':
+                predict_fn = lambda x: get_action(params, x)
+                eval_stat = evaluate(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
+                metric_name = "eval_return"
+            else:
+                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
+                eval_stat = battle(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
+                metric_name = "eval_win_rate"
             if device_thread_id != 0:
-                stats_queue.put(eval_return)
+                eval_queue.put(eval_stat)
             else:
                 eval_stats = []
-                eval_stats.append(eval_return)
+                eval_stats.append(eval_stat)
                 for _ in range(1, n_actors):
-                    eval_stats.append(stats_queue.get())
+                    eval_stats.append(eval_queue.get())
                 eval_stats = np.mean(eval_stats)
-                writer.add_scalar("charts/eval_return", eval_stats, global_step)
+                writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)

                 if device_thread_id == 0:
                     eval_time = time.time() - _start
-                    print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
+                    print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                     other_time += eval_time
@@ -524,15 +552,24 @@ if __name__ == "__main__":
         params=params,
         tx=tx,
     )

     if args.checkpoint:
         with open(args.checkpoint, "rb") as f:
             params = flax.serialization.from_bytes(params, f.read())
             agent_state = agent_state.replace(params=params)
         print(f"loaded checkpoint from {args.checkpoint}")

     agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
     # print(agent.tabulate(agent_key, sample_obs))

+    if args.eval_checkpoint:
+        with open(args.eval_checkpoint, "rb") as f:
+            eval_params = flax.serialization.from_bytes(params, f.read())
+        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
+    else:
+        eval_params = None
+
     @jax.jit
     def get_logits_and_value(
         params: flax.core.FrozenDict, inputs,
@@ -711,7 +748,7 @@ if __name__ == "__main__":

     params_queues = []
     rollout_queues = []
-    stats_queues = queue.Queue()
+    eval_queues = queue.Queue()
     dummy_writer = SimpleNamespace()
     dummy_writer.add_scalar = lambda x, y, z: None
@@ -721,7 +758,9 @@ if __name__ == "__main__":
         for thread_id in range(args.num_actor_threads):
             params_queues.append(queue.Queue(maxsize=1))
             rollout_queues.append(queue.Queue(maxsize=1))
-            params_queues[-1].put(device_params)
+            if eval_params:
+                params_queues[-1].put(
+                    jax.device_put(eval_params, local_devices[d_id]))
             threading.Thread(
                 target=rollout,
                 args=(
@@ -729,12 +768,13 @@ if __name__ == "__main__":
                     args,
                     rollout_queues[-1],
                     params_queues[-1],
-                    stats_queues,
+                    eval_queues,
                     writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                     learner_devices,
                     d_idx * args.num_actor_threads + thread_id,
                 ),
             ).start()
+            params_queues[-1].put(device_params)

     rollout_queue_get_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)
...
 import numpy as np

-def evaluate(envs, act_fn, params, rnn_state=None):
-    num_episodes = envs.num_envs
+def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
     episode_lengths = []
     episode_rewards = []
-    eval_win_rates = []
+    win_rates = []
     obs = envs.reset()[0]
     collected = np.zeros((num_episodes,), dtype=np.bool_)
     while True:
         if rnn_state is None:
-            actions = act_fn(params, obs)
+            actions = predict_fn(obs)
         else:
-            rnn_state, actions = act_fn(params, (rnn_state, obs))
+            rnn_state, actions = predict_fn((rnn_state, obs))
         actions = np.array(actions)

         obs, rewards, dones, info = envs.step(actions)
@@ -27,11 +26,54 @@ def evaluate(envs, act_fn, params, rnn_state=None):
             episode_lengths.append(episode_length)
             episode_rewards.append(episode_reward)
-            eval_win_rates.append(win)
+            win_rates.append(win)
         if len(episode_lengths) >= num_episodes:
             break

     eval_return = np.mean(episode_rewards[:num_episodes])
     eval_ep_len = np.mean(episode_lengths[:num_episodes])
-    eval_win_rate = np.mean(eval_win_rates[:num_episodes])
+    eval_win_rate = np.mean(win_rates[:num_episodes])
     return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
+
+
+def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
+    num_envs = envs.num_envs
+    episode_rewards = []
+    episode_lengths = []
+    win_rates = []
+    obs, infos = envs.reset()
+    next_to_play = infos['to_play']
+    dones = np.zeros(num_envs, dtype=np.bool_)
+    main_player = np.concatenate([
+        np.zeros(num_envs // 2, dtype=np.int64),
+        np.ones(num_envs - num_envs // 2, dtype=np.int64)
+    ])
+    rstate1 = rstate2 = init_rnn_state
+    while True:
+        main = next_to_play == main_player
+        rstate1, rstate2, actions = predict_fn(rstate1, rstate2, obs, main, dones)
+        actions = np.array(actions)
+
+        obs, rewards, dones, infos = envs.step(actions)
+        next_to_play = infos['to_play']
+
+        for idx, d in enumerate(dones):
+            if not d:
+                continue
+            episode_length = infos['l'][idx]
+            episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
+            win = 1 if episode_reward > 0 else 0
+            episode_lengths.append(episode_length)
+            episode_rewards.append(episode_reward)
+            win_rates.append(win)
+
+        if len(episode_lengths) >= num_episodes:
+            break
+
+    eval_return = np.mean(episode_rewards[:num_episodes])
+    eval_ep_len = np.mean(episode_lengths[:num_episodes])
+    eval_win_rate = np.mean(win_rates[:num_episodes])
+    return eval_return, eval_ep_len, eval_win_rate
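For illustration, this is how the two evaluators are driven from the actor thread above (the predict_fn lambdas mirror the training script; params, eval_params, eval_envs, args and eval_rstate are the objects defined there, shown here only as a usage sketch):

if eval_mode == 'bot':
    # single policy against the built-in bot; evaluate() reports its average return
    predict_fn = lambda x: get_action(params, x)
    eval_return, eval_ep_len, eval_win_rate = evaluate(
        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
    # current policy (params) vs. a frozen checkpoint (eval_params)
    predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
    eval_return, eval_ep_len, eval_win_rate = battle(
        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)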