Commit d2dcecde authored by sbl1996@126.com

PPO with LSTM as default

parent 34f86ae4
@@ -24,10 +24,10 @@ from rich.pretty import pprint
 from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro
-from ygoai.rl.jax.agent2 import PPOAgent
+from ygoai.rl.jax.agent2 import PPOLSTMAgent
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
+from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
 from ygoai.rl.jax.eval import evaluate
-from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss
+from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss
 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -41,7 +41,7 @@ class Args:
 """seed of the experiment"""
 log_frequency: int = 10
 """the logging frequency of the model performance (in terms of `updates`)"""
-save_interval: int = 100
+save_interval: int = 400
 """the frequency of saving the model (in terms of `updates`)"""
 checkpoint: Optional[str] = None
 """the path to the model checkpoint to load"""
@@ -66,22 +66,26 @@ class Args:
 total_timesteps: int = 5000000000
 """total timesteps of the experiments"""
-learning_rate: float = 3e-4
+learning_rate: float = 1e-4
 """the learning rate of the optimizer"""
-local_num_envs: int = 64
+local_num_envs: int = 128
 """the number of parallel game environments"""
 local_env_threads: Optional[int] = None
 """the number of threads to use for environment"""
 num_actor_threads: int = 2
 """the number of actor threads to use"""
-num_steps: int = 20
+num_steps: int = 32
 """the number of steps to run in each environment per policy rollout"""
+collect_length: Optional[int] = None
+"""the number of steps to compute the advantages"""
 anneal_lr: bool = False
 """Toggle learning rate annealing for policy and value networks"""
 gamma: float = 1.0
 """the discount factor gamma"""
 num_minibatches: int = 4
 """the number of mini-batches"""
+update_epochs: int = 2
+"""the K epochs to update the policy"""
 c_clip_min: float = 0.001
 """the minimum value of the importance sampling clipping"""
 c_clip_max: float = 1.007
@@ -90,6 +94,8 @@ class Args:
 """the minimum value of the importance sampling clipping"""
 rho_clip_max: float = 1.007
 """the maximum value of the importance sampling clipping"""
+upgo: bool = False
+"""whether to use UPGO for policy update"""
 ppo_clip: bool = True
 """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
 clip_coef: float = 0.25
@@ -105,10 +111,12 @@ class Args:
 """the number of layers for the agent"""
 num_channels: int = 128
 """the number of channels for the agent"""
+rnn_channels: int = 512
+"""the number of channels for the RNN in the agent"""
-actor_device_ids: List[int] = field(default_factory=lambda: [0])
+actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
 """the device ids that actor workers will use"""
-learner_device_ids: List[int] = field(default_factory=lambda: [1])
+learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
 """the device ids that learner workers will use"""
 distributed: bool = False
 """whether to use `jax.distirbuted`"""
@@ -166,18 +174,28 @@ class Transition(NamedTuple):
 obs: list
 dones: list
 actions: list
-logitss: list
+logits: list
 rewards: list
-learns: list
+mains: list
+next_dones: list
-def create_agent(args):
+def create_agent(args, multi_step=False):
-return PPOAgent(
+return PPOLSTMAgent(
 channels=args.num_channels,
 num_layers=args.num_layers,
 embedding_shape=args.num_embeddings,
 dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
 param_dtype=jnp.float32,
+lstm_channels=args.rnn_channels,
+multi_step=multi_step,
+)
+def init_rnn_state(num_envs, rnn_channels):
+return (
+np.zeros((num_envs, rnn_channels)),
+np.zeros((num_envs, rnn_channels)),
 )
@@ -185,8 +203,8 @@ def rollout(
 key: jax.random.PRNGKey,
 args: Args,
 rollout_queue,
-params_queue: queue.Queue,
+params_queue,
-stats_queue,
+eval_queue,
 writer,
 learner_devices,
 device_thread_id,
@@ -218,41 +236,53 @@ def rollout(
 @jax.jit
 def get_logits(
-params: flax.core.FrozenDict, inputs):
+params: flax.core.FrozenDict, inputs, done):
-logits, value, _valid = create_agent(args).apply(params, inputs)[:2]
+rstate, logits = create_agent(args).apply(params, inputs)[:2]
-return logits
+rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
+return rstate, logits
+@jax.jit
 def get_action(
 params: flax.core.FrozenDict, inputs):
-return get_logits(params, inputs).argmax(axis=1)
+batch_size = jax.tree.leaves(inputs)[0].shape[0]
+done = jnp.zeros(batch_size, dtype=jnp.bool_)
+rstate, logits = get_logits(params, inputs, done)
+return rstate, logits.argmax(axis=1)
 @jax.jit
 def sample_action(
 params: flax.core.FrozenDict,
-next_obs, key: jax.random.PRNGKey):
+next_obs, rstate1, rstate2, main, done, key):
 next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
-logits = get_logits(params, next_obs)
+done = jnp.array(done)
-# sample action: Gumbel-softmax trick
+main = jnp.array(main)
-# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
+rstate = jax.tree.map(
-key, subkey = jax.random.split(key)
+lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-u = jax.random.uniform(subkey, shape=logits.shape)
+rstate, logits = get_logits(params, (rstate, next_obs), done)
-action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
+rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
-return next_obs, action, logits, key
+rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
-# put data in the last index
+action, key = categorical_sample(logits, key)
-envs.async_reset()
+return next_obs, done, main, rstate1, rstate2, action, logits, key
+# put data in the last index
 params_queue_get_time = deque(maxlen=10)
 rollout_time = deque(maxlen=10)
 actor_policy_version = 0
-storage = []
+next_obs, info = envs.reset()
+next_to_play = info["to_play"]
+next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
+next_rstate1 = next_rstate2 = init_rnn_state(
+args.local_num_envs, args.rnn_channels)
+eval_rstate = init_rnn_state(
+args.local_eval_episodes, args.rnn_channels)
 main_player = np.concatenate([
 np.zeros(args.local_num_envs // 2, dtype=np.int64),
 np.ones(args.local_num_envs // 2, dtype=np.int64)
 ])
 np.random.shuffle(main_player)
-next_to_play = None
+start_step = 0
-main = np.ones(args.local_num_envs, dtype=np.bool_)
+storage = []
 @jax.jit
 def prepare_data(storage: List[Transition]) -> Transition:
@@ -266,8 +296,6 @@ def rollout(
 update_time_start = time.time()
 inference_time = 0
 env_time = 0
-num_steps_with_bootstrap = (
-args.num_steps + int(len(storage) == 0))
 params_queue_get_time_start = time.time()
 if args.concurrency:
 if update != 2:
@@ -282,47 +310,64 @@ def rollout(
 params_queue_get_time.append(time.time() - params_queue_get_time_start)
 rollout_time_start = time.time()
-for _ in range(0, num_steps_with_bootstrap):
+init_rstate1, init_rstate2 = jax.tree.map(
+lambda x: x.copy(), (next_rstate1, next_rstate2))
+for _ in range(start_step, args.collect_length):
 global_step += args.local_num_envs * n_actors * args.world_size
-_start = time.time()
+cached_next_obs = next_obs
-next_obs, next_reward, next_done, info = envs.recv()
+cached_next_done = next_done
-next_reward = np.where(main, next_reward, -next_reward)
-env_time += time.time() - _start
-to_play = next_to_play
-next_to_play = info["to_play"]
 main = next_to_play == main_player
 inference_time_start = time.time()
-next_obs, action, logits, key = sample_action(params, next_obs, key)
+cached_next_obs, cached_next_done, cached_main, \
+next_rstate1, next_rstate2, action, logits, key = sample_action(
+params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
 cpu_action = np.array(action)
 inference_time += time.time() - inference_time_start
-envs.send(cpu_action)
+_start = time.time()
+next_obs, next_reward, next_done, info = envs.step(cpu_action)
+next_to_play = info["to_play"]
+env_time += time.time() - _start
 storage.append(
 Transition(
-obs=next_obs,
+obs=cached_next_obs,
-dones=next_done,
+dones=cached_next_done,
-mains=main,
+mains=cached_main,
-rewards=next_reward,
 actions=action,
-logitss=logits,
+logits=logits,
+rewards=next_reward,
+next_dones=next_done,
 )
 )
 for idx, d in enumerate(next_done):
 if not d:
 continue
-pl = 1 if to_play[idx] == main_player[idx] else -1
+cur_main = main[idx]
-episode_reward = info['r'][idx] * pl
+for j in reversed(range(len(storage) - 1)):
+t = storage[j]
+if t.next_dones[idx]:
+# For OTK where player may not switch
+break
+if t.mains[idx] != cur_main:
+t.next_dones[idx] = True
+t.rewards[idx] = -next_reward[idx]
+break
+episode_reward = info['r'][idx] * (1 if cur_main else -1)
 win = 1 if episode_reward > 0 else 0
 avg_ep_returns.append(episode_reward)
 avg_win_rates.append(win)
 rollout_time.append(time.time() - rollout_time_start)
+start_step = args.collect_length - args.num_steps
 partitioned_storage = prepare_data(storage)
+storage = storage[args.num_steps:]
 sharded_storage = []
 for x in partitioned_storage:
 if isinstance(x, dict):
@@ -334,21 +379,25 @@ def rollout(
 x = jax.device_put_sharded(x, devices=learner_devices)
 sharded_storage.append(x)
 sharded_storage = Transition(*sharded_storage)
+next_main = main_player == next_to_play
+next_rstate = jax.tree.map(
+lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
+sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
+np.split(x, len(learner_devices)), devices=learner_devices),
+(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
+learn_opponent = False
 payload = (
 global_step,
-actor_policy_version,
 update,
 sharded_storage,
+*sharded_data,
 np.mean(params_queue_get_time),
-device_thread_id,
+learn_opponent,
 )
 rollout_queue.put(payload)
-# move bootstrapping step to the beginning of the next update
-storage = storage[-1:]
 if update % args.log_frequency == 0:
-avg_episodic_return = np.mean(avg_ep_returns) if len(avg_ep_returns) > 0 else 0
+avg_episodic_return = np.mean(avg_ep_returns)
 avg_episodic_length = np.mean(envs.returned_episode_lengths)
 SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
 SPS_update = int(args.batch_size / (time.time() - update_time_start))
@@ -370,14 +419,14 @@ def rollout(
 if args.eval_interval and update % args.eval_interval == 0:
 # Eval with rule-based policy
 _start = time.time()
-eval_return = evaluate(eval_envs, get_action, params)[0]
+eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
 if device_thread_id != 0:
-stats_queue.put(eval_return)
+eval_queue.put(eval_return)
 else:
 eval_stats = []
 eval_stats.append(eval_return)
 for _ in range(1, n_actors):
-eval_stats.append(stats_queue.get())
+eval_stats.append(eval_queue.get())
 eval_stats = np.mean(eval_stats)
 writer.add_scalar("charts/eval_return", eval_stats, global_step)
 if device_thread_id == 0:
@@ -388,21 +437,17 @@
 if __name__ == "__main__":
 args = tyro.cli(Args)
-args.local_batch_size = int(
-args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
-args.local_minibatch_size = int(
-args.local_batch_size // args.num_minibatches)
+args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
+args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
 assert (
 args.local_num_envs % len(args.learner_device_ids) == 0
 ), "local_num_envs must be divisible by len(learner_device_ids)"
 assert (
-int(args.local_num_envs / len(args.learner_device_ids)) *
-args.num_actor_threads % args.num_minibatches == 0
+int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
 ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
 if args.distributed:
 jax.distributed.initialize(
-local_device_ids=range(
-len(args.learner_device_ids) + len(args.actor_device_ids)),
+local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
 )
 print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
@@ -411,13 +456,13 @@ if __name__ == "__main__":
 args.world_size = jax.process_count()
 args.local_rank = jax.process_index()
-args.num_envs = args.local_num_envs * args.world_size * \
-args.num_actor_threads * len(args.actor_device_ids)
+args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
 args.batch_size = args.local_batch_size * args.world_size
 args.minibatch_size = args.local_minibatch_size * args.world_size
-args.num_updates = args.total_timesteps // (
-args.local_batch_size * args.world_size)
+args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
 args.local_env_threads = args.local_env_threads or args.local_num_envs
+args.collect_length = args.collect_length or args.num_steps
+assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
 local_devices = jax.local_devices()
 global_devices = jax.devices()
@@ -429,8 +474,7 @@ if __name__ == "__main__":
 for d_id in args.learner_device_ids
 ]
 print("global_learner_decices", global_learner_decices)
-args.global_learner_decices = [
-str(item) for item in global_learner_decices]
+args.global_learner_decices = [str(item) for item in global_learner_decices]
 args.actor_devices = [str(item) for item in actor_devices]
 args.learner_devices = [str(item) for item in learner_devices]
 pprint(args)
@@ -441,8 +485,7 @@ if __name__ == "__main__":
 writer = SummaryWriter(f"runs/{run_name}")
 writer.add_text(
 "hyperparameters",
-"|param|value|\n|-|-|\n%s" % (
-"\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
 )
 # seeding
@@ -461,18 +504,19 @@ if __name__ == "__main__":
 obs_space = envs.observation_space
 action_shape = envs.action_space.shape
 print(f"obs_space={obs_space}, action_shape={action_shape}")
-sample_obs = jax.tree.map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
+sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
 envs.close()
 del envs
 def linear_schedule(count):
 # anneal learning rate linearly after one training iteration which contains
 # (args.num_minibatches) gradient updates
-frac = 1.0 - (count // (args.num_minibatches)) / args.num_updates
+frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
 return args.learning_rate * frac
+rstate = init_rnn_state(1, args.rnn_channels)
 agent = create_agent(args)
-params = agent.init(agent_key, sample_obs)
+params = agent.init(agent_key, (rstate, sample_obs))
 tx = optax.MultiSteps(
 optax.chain(
 optax.clip_by_global_norm(args.max_grad_norm),
@@ -488,46 +532,55 @@ if __name__ == "__main__":
 tx=tx,
 )
-agent_state = flax.jax_utils.replicate(
-agent_state, devices=learner_devices)
+if args.checkpoint:
+with open(args.checkpoint, "rb") as f:
+params = flax.serialization.from_bytes(params, f.read())
+agent_state = agent_state.replace(params=params)
+print(f"loaded checkpoint from {args.checkpoint}")
+agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
 # print(agent.tabulate(agent_key, sample_obs))
 @jax.jit
 def get_logits_and_value(
-params: flax.core.FrozenDict,
-obs: np.ndarray,
+params: flax.core.FrozenDict, inputs,
 ):
-logits, value = create_agent(args).apply(params, obs)
+rstate, logits, value, valid = create_agent(
+args, multi_step=True).apply(params, inputs)
 return logits, value.squeeze(-1)
-def impala_loss(
+def ppo_loss(
-params, obs, actions, logitss, rewards, dones, learns):
+params, rstate1, rstate2, obs, dones, next_dones,
+switch, actions, logits, rewards, mask, next_value):
-# (num_steps + 1, local_num_envs // n_mb))
+# (num_steps * local_num_envs // n_mb))
-num_steps = actions.shape[0] - 1
+num_envs = next_value.shape[0]
-discounts = (1.0 - dones) * args.gamma
+num_steps = dones.shape[0] // num_envs
-policy_logits, newvalue = jax.vmap(
-get_logits_and_value, in_axes=(None, 0))(params, obs)
+mask = mask & (~dones)
-newvalue = jnp.where(learns, newvalue, -newvalue)
+n_valids = jnp.sum(mask)
+real_dones = dones | next_dones
+inputs = (rstate1, rstate2, obs, real_dones, switch)
+new_logits, new_values = get_logits_and_value(params, inputs)
-v_t = newvalue[1:]
-# Remove bootstrap timestep from non-timesteps.
-v_tm1 = newvalue[:-1]
-policy_logits = policy_logits[:-1]
-logitss = logitss[:-1]
-actions = actions[:-1]
-mask = 1.0 - dones
-rewards = rewards[1:]
-discounts = discounts[1:]
-mask = mask[:-1]
-rhos = distrax.importance_sampling_ratios(distrax.Categorical(
-policy_logits), distrax.Categorical(logitss), actions)
+new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map(
+lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
+(new_logits, new_values, logits, actions, rewards, next_dones, switch, mask),
+)
+v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0)
+discounts = (1.0 - next_dones) * args.gamma
+ratio = distrax.importance_sampling_ratios(distrax.Categorical(
+new_logits), distrax.Categorical(logits), actions)
+logratio = jnp.log(ratio)
+approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
+# TODO: use switch to calculate the correct value
 vtrace_fn = partial(
 vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
 vtrace_returns = jax.vmap(
 vtrace_fn, in_axes=1, out_axes=1)(
-v_tm1, v_t, rewards, discounts, rhos)
+v_tm1, v_t, rewards, discounts, ratio)
 if args.upgo:
 advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
@@ -537,96 +590,140 @@ if __name__ == "__main__":
 if args.ppo_clip:
 pg_loss = jax.vmap(
 partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
-rhos, advs, mask) * num_steps
+ratio, advs, mask) * num_steps
 pg_loss = jnp.sum(pg_loss)
 else:
-pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
+pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs
 pg_loss = jax.vmap(
 rlax.policy_gradient_loss, in_axes=1)(
-policy_logits, actions, pg_advs, mask) * num_steps
+new_logits, actions, pg_advs, mask) * num_steps
 pg_loss = jnp.sum(pg_loss)
-baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
+v_loss = 0.5 * (vtrace_returns.errors ** 2)
+v_loss = jnp.sum(v_loss * mask)
-ent_loss = jax.vmap(rlax.entropy_loss, in_axes=1)(
-policy_logits, mask) * T
-ent_loss = jnp.sum(ent_loss)
+entropy_loss = distrax.Softmax(new_logits).entropy()
+entropy_loss = jnp.sum(entropy_loss * mask)
-n_samples = jnp.sum(mask)
-pg_loss = pg_loss / n_samples
-baseline_loss = baseline_loss / n_samples
-ent_loss = ent_loss / n_samples
+pg_loss = pg_loss / n_valids
+v_loss = v_loss / n_valids
+entropy_loss = entropy_loss / n_valids
-total_loss = pg_loss
-total_loss += args.vf_coef * baseline_loss
-total_loss += args.ent_coef * ent_loss
-return total_loss, (pg_loss, baseline_loss, ent_loss)
+loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
+return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
-@jax.jit
 def single_device_update(
 agent_state: TrainState,
-sharded_storages: List[Transition],
+sharded_storages: List,
+sharded_init_rstate1: List,
+sharded_init_rstate2: List,
+sharded_next_inputs: List,
+sharded_next_main: List,
 key: jax.random.PRNGKey,
+learn_opponent: bool = False,
 ):
 storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-impala_loss_grad_fn = jax.value_and_grad(impala_loss, has_aux=True)
-def update_minibatch(agent_state, minibatch):
-mb_obs, mb_actions, mb_logitss, mb_rewards, mb_dones, mb_learns = minibatch
-(loss, (pg_loss, v_loss, entropy_loss)), grads = impala_loss_grad_fn(
-agent_state.params,
-mb_obs,
-mb_actions,
-mb_logitss,
-mb_rewards,
-mb_dones,
-mb_learns,
+next_inputs, init_rstate1, init_rstate2 = [
+jax.tree.map(lambda *x: jnp.concatenate(x), *x)
+for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
+]
+next_main, = [
+jnp.concatenate(x) for x in [sharded_next_main]
+]
+# reorder storage of individual players
+# main first, opponent second
+num_steps, num_envs = storage.rewards.shape
+T = jnp.arange(num_steps, dtype=jnp.int32)
+B = jnp.arange(num_envs, dtype=jnp.int32)
+mains = storage.mains.astype(jnp.int32)
+indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
+switch_steps = jnp.sum(mains, axis=0)
+switch = T[:, None] == (switch_steps[None, :] - 1)
+storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
+ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
+def update_epoch(carry, _):
+agent_state, key = carry
+key, subkey = jax.random.split(key)
+next_value = create_agent(args).apply(
+agent_state.params, next_inputs)[2].squeeze(-1)
+# TODO: check if this is correct
+sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
+next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
+def convert_data(x: jnp.ndarray, num_steps):
+if args.update_epochs > 1:
+x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
+N = args.num_minibatches
+if num_steps > 1:
+x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
+x = x.transpose(1, 0, *range(2, x.ndim))
+x = x.reshape(N, -1, *x.shape[3:])
+else:
+x = jnp.reshape(x, (N, -1) + x.shape[1:])
+return x
+shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
+partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
+shuffled_storage, shuffled_switch = jax.tree.map(
+partial(convert_data, num_steps=num_steps), (storage, switch))
+shuffled_mask = jnp.ones_like(shuffled_storage.mains)
+def update_minibatch(agent_state, minibatch):
+(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
+agent_state.params, *minibatch)
+grads = jax.lax.pmean(grads, axis_name="local_devices")
+agent_state = agent_state.apply_gradients(grads=grads)
+return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
+agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
+update_minibatch,
+agent_state,
+(
+shuffled_init_rstate1,
+shuffled_init_rstate2,
+shuffled_storage.obs,
+shuffled_storage.dones,
+shuffled_storage.next_dones,
+shuffled_switch,
+shuffled_storage.actions,
+shuffled_storage.logits,
+shuffled_storage.rewards,
+shuffled_mask,
+shuffled_next_value,
+),
 )
-grads = jax.lax.pmean(grads, axis_name="local_devices")
-agent_state = agent_state.apply_gradients(grads=grads)
-return agent_state, (loss, pg_loss, v_loss, entropy_loss)
+return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
-n_mb = args.num_minibatches * args.gradient_accumulation_steps
-storage_obs = {
-k: jnp.array(jnp.split(v, n_mb, axis=1))
-for k, v in storage.obs.items()
-}
-agent_state, (loss, pg_loss, v_loss, entropy_loss) = jax.lax.scan(
-update_minibatch,
-agent_state,
-(
-# (num_steps + 1, local_num_envs) => (n_mb, num_steps + 1, local_num_envs // n_mb)
-storage_obs,
-jnp.array(jnp.split(storage.actions, n_mb, axis=1)),
-jnp.array(jnp.split(storage.logitss, n_mb, axis=1)),
-jnp.array(jnp.split(storage.rewards, n_mb, axis=1)),
-jnp.array(jnp.split(storage.dones, n_mb, axis=1)),
-jnp.array(jnp.split(storage.learns, n_mb, axis=1)),
-),
+(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
+update_epoch, (agent_state, key), (), length=args.update_epochs
 )
 loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
 pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
 v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
-entropy_loss = jax.lax.pmean(
-entropy_loss, axis_name="local_devices").mean()
+entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
+approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
-return agent_state, loss, pg_loss, v_loss, entropy_loss, key
+return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
 multi_device_update = jax.pmap(
 single_device_update,
 axis_name="local_devices",
 devices=global_learner_decices,
+static_broadcasted_argnums=(7,),
 )
 params_queues = []
 rollout_queues = []
-stats_queues = queue.Queue()
+eval_queue = queue.Queue()
 dummy_writer = SimpleNamespace()
 dummy_writer.add_scalar = lambda x, y, z: None
 unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
 for d_idx, d_id in enumerate(args.actor_device_ids):
-device_params = jax.device_put(
-unreplicated_params, local_devices[d_id])
+device_params = jax.device_put(unreplicated_params, local_devices[d_id])
 for thread_id in range(args.num_actor_threads):
 params_queues.append(queue.Queue(maxsize=1))
 rollout_queues.append(queue.Queue(maxsize=1))
@@ -638,7 +735,7 @@ if __name__ == "__main__":
 args,
 rollout_queues[-1],
 params_queues[-1],
-stats_queues,
+eval_queue,
 writer if d_idx == 0 and thread_id == 0 else dummy_writer,
 learner_devices,
 d_idx * args.num_actor_threads + thread_id,
@@ -651,50 +748,47 @@ if __name__ == "__main__":
 while True:
 learner_policy_version += 1
 rollout_queue_get_time_start = time.time()
-sharded_storages = []
+sharded_data_list = []
 for d_idx, d_id in enumerate(args.actor_device_ids):
 for thread_id in range(args.num_actor_threads):
 (
 global_step,
-actor_policy_version,
 update,
-sharded_storage,
+*sharded_data,
 avg_params_queue_get_time,
-device_thread_id,
+learn_opponent,
 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
-sharded_storages.append(sharded_storage)
+sharded_data_list.append(sharded_data)
-rollout_queue_get_time.append(
-time.time() - rollout_queue_get_time_start)
+rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
 training_time_start = time.time()
-(agent_state, loss, pg_loss, v_loss, entropy_loss, learner_keys) = multi_device_update(
+(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
 agent_state,
-sharded_storages,
+*list(zip(*sharded_data_list)),
 learner_keys,
+learn_opponent,
 )
 unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
 for d_idx, d_id in enumerate(args.actor_device_ids):
-device_params = jax.device_put(
-unreplicated_params, local_devices[d_id])
+device_params = jax.device_put(unreplicated_params, local_devices[d_id])
 device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
 for thread_id in range(args.num_actor_threads):
-params_queues[d_idx * args.num_actor_threads +
-thread_id].put(device_params)
+params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
+loss = loss[-1].item()
+if np.isnan(loss) or np.isinf(loss):
+raise ValueError(f"loss is {loss}")
 # record rewards for plotting purposes
 if learner_policy_version % args.log_frequency == 0:
-writer.add_scalar("stats/rollout_queue_get_time",
-np.mean(rollout_queue_get_time), global_step)
+writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
 writer.add_scalar(
 "stats/rollout_params_queue_get_time_diff",
 np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
 global_step,
 )
-writer.add_scalar("stats/training_time",
-time.time() - training_time_start, global_step)
-writer.add_scalar("stats/rollout_queue_size",
-rollout_queues[-1].qsize(), global_step)
-writer.add_scalar("stats/params_queue_size",
-params_queues[-1].qsize(), global_step)
+writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
+writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
+writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
 print(
 global_step,
 f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
@@ -702,28 +796,22 @@ if __name__ == "__main__":
 writer.add_scalar(
 "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
 )
-writer.add_scalar("losses/value_loss",
-v_loss[-1].item(), global_step)
-writer.add_scalar("losses/policy_loss",
-pg_loss[-1].item(), global_step)
-writer.add_scalar("losses/entropy",
-entropy_loss[-1].item(), global_step)
-writer.add_scalar("losses/loss", loss[-1].item(), global_step)
+writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
+writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
+writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
+writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
+writer.add_scalar("losses/loss", loss, global_step)
 if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
-ckpt_dir = f"checkpoints/{run_name}"
+ckpt_dir = f"checkpoints"
 os.makedirs(ckpt_dir, exist_ok=True)
-model_path = ckpt_dir + "/agent.cleanrl_model"
+M_steps = args.batch_size * learner_policy_version // (2**20)
+model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
 with open(model_path, "wb") as f:
 f.write(
-flax.serialization.to_bytes(
-[
-vars(args),
-unreplicated_params,
-]
-)
+flax.serialization.to_bytes(unreplicated_params)
 )
 print(f"model saved to {model_path}")
 if learner_policy_version >= args.num_updates:
 break
@@ -731,4 +819,4 @@ if __name__ == "__main__":
 if args.distributed:
 jax.distributed.shutdown()
 writer.close()
\ No newline at end of file
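The diff above replaces the inline Gumbel-softmax sampling with the `categorical_sample` helper imported from `ygoai.rl.jax.utils`. A minimal sketch of such a helper, assuming it simply wraps the same Gumbel-max trick the removed code performed inline (the actual implementation in `ygoai` may differ):

import jax
import jax.numpy as jnp

def categorical_sample(logits, key):
    # Draw one action per row from the categorical distribution defined by `logits`
    # via the Gumbel-max trick; return the action and the advanced PRNG key.
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=logits.shape)
    action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
    return action, key

This matches the call sites in the new code, e.g. `action, key = categorical_sample(logits, key)` inside `sample_action`.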
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-4
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 32
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
num_minibatches: int = 4
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
c_clip_min: float = 0.001
"""the minimum value of the importance sampling clipping"""
c_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping"""
rho_clip_min: float = 0.001
"""the minimum value of the importance sampling clipping"""
rho_clip_max: float = 1.007
"""the maximum value of the importance sampling clipping"""
upgo: bool = False
"""whether to use UPGO for policy update"""
ppo_clip: bool = True
"""whether to use the PPO clipping to replace V-Trace surrogate clipping"""
clip_coef: float = 0.25
"""the PPO surrogate clipping coefficient"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logits: list
rewards: list
mains: list
next_dones: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
)
def init_rnn_state(num_envs, rnn_channels):
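# Zero-initialized recurrent state for the LSTM agent: two arrays of shape (num_envs, rnn_channels).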
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs, done):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
main = jnp.array(main)
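# rstate1/rstate2 hold each player's recurrent state; select the acting player's state for this
# forward pass, then write the updated state back to that player's slot below.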
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs), done)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
action, key = categorical_sample(logits, key)
return next_obs, rstate1, rstate2, action, logits, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, next_rstate1, next_rstate2, action, logits, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=main,
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
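# Walk back through the stored transitions and mark the opponent's most recent step as terminal
# with the negated final reward, so both players' trajectories end with a correctly signed outcome.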
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
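# The checkpoint stores only the raw parameter pytree (written with flax.serialization.to_bytes),
# so it is restored into `params` built from the same agent configuration.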
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
def ppo_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch, actions, logits, rewards, mask, next_value):
# (num_steps * local_num_envs // n_mb))
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
mask = mask & (~dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
(new_logits, new_values, logits, actions, rewards, next_dones, switch, mask),
)
v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0)
discounts = (1.0 - next_dones) * args.gamma
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
# TODO: use switch to calculate the correct value
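# V-trace targets are computed per environment by vmapping over axis 1 (the
# env axis); `ratio` is the importance weight of the taken action under the
# new policy versus the behavior policy.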
vtrace_fn = partial(
vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
vtrace_returns = jax.vmap(
vtrace_fn, in_axes=1, out_axes=1)(
v_tm1, v_t, rewards, discounts, ratio)
if args.upgo:
advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
rewards, v_t, discounts) - v_tm1
else:
advs = vtrace_returns.q_estimate - v_tm1
if args.ppo_clip:
pg_loss = jax.vmap(
partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
ratio, advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
else:
pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs
pg_loss = jax.vmap(
rlax.policy_gradient_loss, in_axes=1)(
new_logits, actions, pg_advs, mask) * num_steps
pg_loss = jnp.sum(pg_loss)
v_loss = 0.5 * (vtrace_returns.errors ** 2)
v_loss = jnp.sum(v_loss * mask)
entropy_loss = distrax.Softmax(new_logits).entropy()
entropy_loss = jnp.sum(entropy_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main, = [
jnp.concatenate(x) for x in [sharded_next_main]
]
# reorder storage of individual players
# main first, opponent second
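# `indices` sorts each env's timesteps so the main player's steps come first,
# and `switch` marks the last main-player step, i.e. where the perspective
# flips to the opponent.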
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
def convert_data(x: jnp.ndarray, num_steps):
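# Shuffle along the env axis only (sequences stay intact for the LSTM), then
# split into num_minibatches chunks laid out time-major; using the same subkey
# keeps the env permutation consistent between per-step and per-env tensors.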
if args.update_epochs > 1:
x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
N = args.num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(7,),
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
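# Main learner loop: block until every actor thread has delivered a rollout,
# run one pmapped update, then push the fresh params back to the actors.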
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
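# Touch one parameter to force the host-to-device transfer to complete before
# handing the new params to the actor thread.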
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
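# Name the checkpoint by the approximate number of environment steps seen so
# far, in millions (2**20 steps).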
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
...@@ -16,15 +16,16 @@ import jax ...@@ -16,15 +16,16 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
import optax import optax
import distrax
import tyro import tyro
from flax.training.train_state import TrainState from flax.training.train_state import TrainState
from rich.pretty import pprint from rich.pretty import pprint
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOAgent from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
...@@ -105,6 +106,8 @@ class Args: ...@@ -105,6 +106,8 @@ class Args:
"""the number of layers for the agent""" """the number of layers for the agent"""
num_channels: int = 128 num_channels: int = 128
"""the number of channels for the agent""" """the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1]) actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use""" """the device ids that actor workers will use"""
...@@ -119,6 +122,8 @@ class Args: ...@@ -119,6 +122,8 @@ class Args:
thread_affinity: bool = False thread_affinity: bool = False
"""whether to use thread affinity for the environment""" """whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32 local_eval_episodes: int = 32
"""the number of episodes to evaluate the model""" """the number of episodes to evaluate the model"""
eval_interval: int = 50 eval_interval: int = 50
...@@ -166,19 +171,28 @@ class Transition(NamedTuple): ...@@ -166,19 +171,28 @@ class Transition(NamedTuple):
obs: list obs: list
dones: list dones: list
actions: list actions: list
logprobs: list logits: list
rewards: list rewards: list
mains: list mains: list
probs: list next_dones: list
def create_agent(args): def create_agent(args, multi_step=False):
return PPOAgent( return PPOLSTMAgent(
channels=args.num_channels, channels=args.num_channels,
num_layers=args.num_layers, num_layers=args.num_layers,
embedding_shape=args.num_embeddings, embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32, dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32, param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
) )
...@@ -186,12 +200,16 @@ def rollout( ...@@ -186,12 +200,16 @@ def rollout(
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
args: Args, args: Args,
rollout_queue, rollout_queue,
params_queue: queue.Queue, params_queue,
stats_queue, eval_queue,
writer, writer,
learner_devices, learner_devices,
device_thread_id, device_thread_id,
): ):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env( envs = make_env(
args, args,
args.seed + jax.process_index() + device_thread_id, args.seed + jax.process_index() + device_thread_id,
...@@ -205,7 +223,7 @@ def rollout( ...@@ -205,7 +223,7 @@ def rollout(
args, args,
args.seed + jax.process_index() + device_thread_id, args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes, args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot') args.local_eval_episodes // 4, mode=eval_mode)
eval_envs = RecordEpisodeStatistics(eval_envs) eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids) len_actor_device_ids = len(args.actor_device_ids)
...@@ -219,30 +237,45 @@ def rollout( ...@@ -219,30 +237,45 @@ def rollout(
@jax.jit @jax.jit
def get_logits( def get_logits(
params: flax.core.FrozenDict, next_obs): params: flax.core.FrozenDict, inputs, done):
return create_agent(args).apply(params, next_obs)[0] rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action( def get_action(
params: flax.core.FrozenDict, next_obs): params: flax.core.FrozenDict, inputs):
return get_logits(params, next_obs).argmax(axis=1) batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
@jax.jit @jax.jit
def sample_action( def sample_action(
params: flax.core.FrozenDict, params: flax.core.FrozenDict,
next_obs, key: jax.random.PRNGKey): next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs) next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
logits = get_logits(params, next_obs) done = jnp.array(done)
# sample action: Gumbel-softmax trick main = jnp.array(main)
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution rstate = jax.tree.map(
key, subkey = jax.random.split(key) lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
u = jax.random.uniform(subkey, shape=logits.shape) rstate, logits = get_logits(params, (rstate, next_obs), done)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1) rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) action, key = categorical_sample(logits, key)
logits = logits.clip(min=jnp.finfo(logits.dtype).min) return next_obs, done, main, rstate1, rstate2, action, logits, key
probs = jax.nn.softmax(logits)
return next_obs, action, logprob, probs, key
# put data in the last index # put data in the last index
params_queue_get_time = deque(maxlen=10) params_queue_get_time = deque(maxlen=10)
...@@ -251,6 +284,10 @@ def rollout( ...@@ -251,6 +284,10 @@ def rollout(
next_obs, info = envs.reset() next_obs, info = envs.reset()
next_to_play = info["to_play"] next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_) next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([ main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64), np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64) np.ones(args.local_num_envs // 2, dtype=np.int64)
...@@ -285,6 +322,8 @@ def rollout( ...@@ -285,6 +322,8 @@ def rollout(
params_queue_get_time.append(time.time() - params_queue_get_time_start) params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time() rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length): for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size global_step += args.local_num_envs * n_actors * args.world_size
...@@ -293,13 +332,14 @@ def rollout( ...@@ -293,13 +332,14 @@ def rollout(
main = next_to_play == main_player main = next_to_play == main_player
inference_time_start = time.time() inference_time_start = time.time()
cached_next_obs, action, logprob, probs, key = sample_action( cached_next_obs, cached_next_done, cached_main, \
params, cached_next_obs, key) next_rstate1, next_rstate2, action, logits, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action) cpu_action = np.array(action)
inference_time += time.time() - inference_time_start inference_time += time.time() - inference_time_start
_start = time.time() _start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action) next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"] next_to_play = info["to_play"]
env_time += time.time() - _start env_time += time.time() - _start
...@@ -308,11 +348,11 @@ def rollout( ...@@ -308,11 +348,11 @@ def rollout(
Transition( Transition(
obs=cached_next_obs, obs=cached_next_obs,
dones=cached_next_done, dones=cached_next_done,
mains=cached_main,
actions=action, actions=action,
logprobs=logprob, logits=logits,
rewards=next_reward, rewards=next_reward,
mains=main, next_dones=next_done,
probs=probs,
) )
) )
...@@ -322,15 +362,14 @@ def rollout( ...@@ -322,15 +362,14 @@ def rollout(
cur_main = main[idx] cur_main = main[idx]
for j in reversed(range(len(storage) - 1)): for j in reversed(range(len(storage) - 1)):
t = storage[j] t = storage[j]
if t.dones[idx]: if t.next_dones[idx]:
# For OTK where player may not switch # For OTK where player may not switch
break break
if t.mains[idx] != cur_main: if t.mains[idx] != cur_main:
t.dones[idx] = True t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx] t.rewards[idx] = -next_reward[idx]
break break
pl = 1 if to_play[idx] == main_player[idx] else -1 episode_reward = info['r'][idx] * (1 if cur_main else -1)
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward) avg_ep_returns.append(episode_reward)
avg_win_rates.append(win) avg_win_rates.append(win)
...@@ -353,17 +392,19 @@ def rollout( ...@@ -353,17 +392,19 @@ def rollout(
sharded_storage.append(x) sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage) sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded( sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices), np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_done, next_main)) (init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
learn_opponent = False
payload = ( payload = (
global_step, global_step,
actor_policy_version,
update, update,
sharded_storage, sharded_storage,
*sharded_data, *sharded_data,
np.mean(params_queue_get_time), np.mean(params_queue_get_time),
device_thread_id, learn_opponent,
) )
rollout_queue.put(payload) rollout_queue.put(payload)
...@@ -390,19 +431,28 @@ def rollout( ...@@ -390,19 +431,28 @@ def rollout(
if args.eval_interval and update % args.eval_interval == 0: if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy # Eval with rule-based policy
_start = time.time() _start = time.time()
eval_return = evaluate(eval_envs, get_action, params)[0] if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
metric_name = "eval_return"
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
metric_name = "eval_win_rate"
if device_thread_id != 0: if device_thread_id != 0:
stats_queue.put(eval_return) eval_queue.put(eval_stat)
else: else:
eval_stats = [] eval_stats = []
eval_stats.append(eval_return) eval_stats.append(eval_stat)
for _ in range(1, n_actors): for _ in range(1, n_actors):
eval_stats.append(stats_queue.get()) eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats) eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step) writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
if device_thread_id == 0: if device_thread_id == 0:
eval_time = time.time() - _start eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}") print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time other_time += eval_time
...@@ -485,8 +535,9 @@ if __name__ == "__main__": ...@@ -485,8 +535,9 @@ if __name__ == "__main__":
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args) agent = create_agent(args)
params = agent.init(agent_key, sample_obs) params = agent.init(agent_key, (rstate, sample_obs))
tx = optax.MultiSteps( tx = optax.MultiSteps(
optax.chain( optax.chain(
optax.clip_by_global_norm(args.max_grad_norm), optax.clip_by_global_norm(args.max_grad_norm),
...@@ -501,6 +552,7 @@ if __name__ == "__main__": ...@@ -501,6 +552,7 @@ if __name__ == "__main__":
params=params, params=params,
tx=tx, tx=tx,
) )
if args.checkpoint: if args.checkpoint:
with open(args.checkpoint, "rb") as f: with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read()) params = flax.serialization.from_bytes(params, f.read())
...@@ -510,36 +562,62 @@ if __name__ == "__main__": ...@@ -510,36 +562,62 @@ if __name__ == "__main__":
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices) agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs)) # print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit @jax.jit
def get_logprob_entropy_value( def get_logits_and_value(
params: flax.core.FrozenDict, obs, actions, params: flax.core.FrozenDict, inputs,
): ):
logits, value, valid = create_agent(args).apply(params, obs) rstate, logits, value, valid = create_agent(
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions] args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1)
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss( def ppo_loss(
params, inputs, actions, logprobs, probs, advantages, target_values): params, rstate1, rstate2, obs, dones, next_dones,
newlogprob, newprobs, entropy, newvalue, valid = \ switch, actions, logits, rewards, mask, next_value):
get_logprob_entropy_value(params, inputs, actions) # (num_steps * local_num_envs // n_mb))
logratio = newlogprob - logprobs num_envs = next_value.shape[0]
ratio = jnp.exp(logratio) num_steps = dones.shape[0] // num_envs
approx_kl = ((ratio - 1) - logratio).mean()
mask = mask & (~dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
values, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs)),
(jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
)
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, values, rewards, next_dones, switch,
args.gamma, args.gae_lambda)
advantages, target_values = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv: if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8) advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss # Policy loss
if args.spo_kld_max is not None: if args.spo_kld_max is not None:
probs = jax.nn.softmax(logits)
new_probs = jax.nn.softmax(new_logits)
eps = 1e-8 eps = 1e-8
kld = jnp.sum( kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1) probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max) kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps) d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio) d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
...@@ -550,85 +628,80 @@ if __name__ == "__main__": ...@@ -550,85 +628,80 @@ if __name__ == "__main__":
pg_loss1 = -advantages * ratio pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2) pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid) pg_loss = jnp.sum(pg_loss * mask)
v_loss = 0.5 * ((new_values - target_values) ** 2)
v_loss = jnp.sum(v_loss * mask)
# Value loss entropy_loss = distrax.Softmax(new_logits).entropy()
v_loss = 0.5 * ((newvalue - target_values) ** 2) entropy_loss = jnp.sum(entropy_loss * mask)
v_loss = masked_mean(v_loss, valid)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update( def single_device_update(
agent_state: TrainState, agent_state: TrainState,
sharded_storages: List, sharded_storages: List,
sharded_next_obs: List, sharded_init_rstate1: List,
sharded_next_done: List, sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_main: List, sharded_next_main: List,
key: jax.random.PRNGKey, key: jax.random.PRNGKey,
learn_opponent: bool = False,
): ):
def reshape_minibatch(x, num_minibatches, multi_step=False):
N = num_minibatches
if multi_step:
x = jnp.reshape(x, (N, -1) + x.shape[2:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages) storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree.map(lambda *x: jnp.concatenate(x), *sharded_next_obs) next_inputs, init_rstate1, init_rstate2 = [
next_done, next_main = [ jax.tree.map(lambda *x: jnp.concatenate(x), *x)
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main] for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main, = [
jnp.concatenate(x) for x in [sharded_next_main]
] ]
# reorder storage of individual players # reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32) T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32) B = jnp.arange(num_envs, dtype=jnp.int32)
mains = (storage.mains == next_main).astype(jnp.int32) mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] + mains * num_steps, axis=0) indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch = T[:, None] == (num_steps - 1 - jnp.sum(mains, axis=0)) switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage) storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
# split minibatches for recompute values
n_mbs = args.num_minibatches // 8
split_inputs = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs, multi_step=True), storage.obs)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _): def update_epoch(carry, _):
agent_state, key = carry agent_state, key = carry
key, subkey = jax.random.split(key) key, subkey = jax.random.split(key)
def get_value_minibatch(agent_state, mb_inputs):
values = create_agent(args).apply(
agent_state.params, mb_inputs)[1].squeeze(-1)
return agent_state, values
_, values = jax.lax.scan(
get_value_minibatch, agent_state, split_inputs)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply( next_value = create_agent(args).apply(
agent_state.params, next_obs)[1].squeeze(-1) agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
advantages, target_values = compute_gae_fn( next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda) def convert_data(x: jnp.ndarray, num_steps):
advantages = advantages[:args.num_steps] if args.update_epochs > 1:
target_values = target_values[:args.num_steps] x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
N = args.num_minibatches
def convert_data(x: jnp.ndarray): if num_steps > 1:
x = x.reshape(-1, *x.shape[2:]) x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = jax.random.permutation(subkey, x) x = x.transpose(1, 0, *range(2, x.ndim))
return reshape_minibatch(x, args.num_minibatches) x = x.reshape(N, -1, *x.shape[3:])
else:
shuffled_storage, shuffled_advantages, shuffled_target_values = jax.tree.map( x = jnp.reshape(x, (N, -1) + x.shape[1:])
convert_data, (storage, advantages, target_values)) return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch): def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
...@@ -641,12 +714,17 @@ if __name__ == "__main__": ...@@ -641,12 +714,17 @@ if __name__ == "__main__":
update_minibatch, update_minibatch,
agent_state, agent_state,
( (
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs, shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.actions, shuffled_storage.actions,
shuffled_storage.logprobs, shuffled_storage.logits,
shuffled_storage.probs, shuffled_storage.rewards,
shuffled_advantages, shuffled_mask,
shuffled_target_values, shuffled_next_value,
), ),
) )
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
...@@ -665,11 +743,12 @@ if __name__ == "__main__": ...@@ -665,11 +743,12 @@ if __name__ == "__main__":
single_device_update, single_device_update,
axis_name="local_devices", axis_name="local_devices",
devices=global_learner_decices, devices=global_learner_decices,
static_broadcasted_argnums=(7,),
) )
params_queues = [] params_queues = []
rollout_queues = [] rollout_queues = []
stats_queues = queue.Queue() eval_queue = queue.Queue()
dummy_writer = SimpleNamespace() dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None dummy_writer.add_scalar = lambda x, y, z: None
...@@ -679,7 +758,9 @@ if __name__ == "__main__": ...@@ -679,7 +758,9 @@ if __name__ == "__main__":
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1)) params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params) if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread( threading.Thread(
target=rollout, target=rollout,
args=( args=(
...@@ -687,12 +768,13 @@ if __name__ == "__main__": ...@@ -687,12 +768,13 @@ if __name__ == "__main__":
args, args,
rollout_queues[-1], rollout_queues[-1],
params_queues[-1], params_queues[-1],
stats_queues, eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer, writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices, learner_devices,
d_idx * args.num_actor_threads + thread_id, d_idx * args.num_actor_threads + thread_id,
), ),
).start() ).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10) rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10) data_transfer_time = deque(maxlen=10)
...@@ -705,11 +787,10 @@ if __name__ == "__main__": ...@@ -705,11 +787,10 @@ if __name__ == "__main__":
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
( (
global_step, global_step,
actor_policy_version,
update, update,
*sharded_data, *sharded_data,
avg_params_queue_get_time, avg_params_queue_get_time,
device_thread_id, learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get() ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data) sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
...@@ -718,6 +799,7 @@ if __name__ == "__main__": ...@@ -718,6 +799,7 @@ if __name__ == "__main__":
agent_state, agent_state,
*list(zip(*sharded_data_list)), *list(zip(*sharded_data_list)),
learner_keys, learner_keys,
learn_opponent,
) )
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params) unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
......
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).removesuffix(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logprobs: list
rewards: list
mains: list
probs: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
)
def init_rnn_state(num_envs, rnn_channels):
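# Zero-initialized LSTM carry: two arrays of shape (num_envs, rnn_channels),
# i.e. the hidden and cell state of the agent's recurrent core.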
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue: queue.Queue,
stats_queue,
writer,
learner_devices,
device_thread_id,
):
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode='bot')
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs, done):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
main = jnp.array(main)
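# Advance the recurrent state of whichever player acts in each env: pick
# player 1's state where `main` is True (player 2's otherwise), run the agent,
# then write the updated state back into the matching slot only.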
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs), done)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
action, key = categorical_sample(logits, key)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
return next_obs, rstate1, rstate2, action, logprob, probs, key
# rolling timing statistics for this actor thread
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
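# Stack the per-step Transitions into (num_steps, num_envs, ...) arrays and
# split them along the env axis, one shard per learner device.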
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
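# Snapshot the recurrent states at the start of the rollout; the learner
# replays the collected sequence from these initial states.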
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, next_rstate1, next_rstate2, action, logprob, probs, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=main,
actions=action,
logprobs=logprob,
probs=probs,
rewards=next_reward,
)
)
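# When an episode ends, walk backwards to the most recent step taken by the
# other player and mark it as terminal with the negated final reward, so each
# player's trajectory ends with its own outcome.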
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
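# If collect_length > num_steps, the surplus transitions are kept in `storage`
# and reused at the start of the next rollout, so only num_steps new steps are
# collected per update.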
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
if device_thread_id != 0:
stats_queue.put(eval_return)
else:
eval_stats = []
eval_stats.append(eval_return)
for _ in range(1, n_actors):
eval_stats.append(stats_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar("charts/eval_return", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal the learning rate linearly; each training iteration performs
# (args.num_minibatches * args.update_epochs) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
@jax.jit
def get_logprob_entropy_value(
params: flax.core.FrozenDict, inputs, actions,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
logprob = jax.nn.log_softmax(logits)[jnp.arange(actions.shape[0]), actions]
logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
logits = logits.clip(min=jnp.finfo(logits.dtype).min)
probs = jax.nn.softmax(logits)
p_log_p = logits * probs
entropy = -p_log_p.sum(-1)
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(
params, inputs, actions, logprobs, probs, advantages, target_values, mask):
newlogprob, newprobs, entropy, newvalue, valid = \
get_logprob_entropy_value(params, inputs, actions)
valid = valid & mask
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
if args.norm_adv:
advantages = masked_normalize(advantages, valid, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (newprobs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
v_loss = 0.5 * ((newvalue - target_values) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
def reshape_minibatch(x, num_minibatches, num_steps=1):
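# Split a batch into num_minibatches chunks; when num_steps > 1 the input is
# (num_steps, num_envs, ...) and each chunk keeps whole sequences, flattened
# time-major to (num_steps * envs_per_chunk, ...).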
N = num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_done, next_main = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]
]
# reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
if not learn_opponent:
num_steps = int(num_steps * 0.75)
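# Keep only the first 3/4 of the reordered steps for training, plus one extra row
# (index num_steps) that is used below solely to provide the bootstrap obs/done/main.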
indices = indices[:num_steps + 1]
switch = switch[:num_steps]
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
if not learn_opponent:
next_obs = jax.tree.map(lambda x: x[num_steps], storage.obs)
next_done = storage.dones[num_steps]
next_main = storage.mains[num_steps]
storage = jax.tree.map(lambda x: x[:num_steps], storage)
# split minibatches for recompute values
num_minibatches = args.num_minibatches
if not learn_opponent:
num_minibatches = num_minibatches // 2
n_mbs = num_minibatches // 4
split_init_rstate = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs),
(init_rstate1, init_rstate2))
split_inputs = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs, num_steps=num_steps),
(storage.obs, storage.dones, switch))
split_inputs = split_init_rstate + split_inputs
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
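# Recompute the rollout's value estimates with the current parameters in n_mbs chunks to
# bound peak memory, then stitch them back into the (num_steps, num_envs) layout for GAE.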
def get_value_minibatch(agent_state, mb_inputs):
values = create_agent(args, multi_step=True).apply(
agent_state.params, mb_inputs)[2].squeeze(-1)
return agent_state, values
_, values = jax.lax.scan(
get_value_minibatch, agent_state, split_inputs)
values = values.reshape((n_mbs, num_steps, -1)).transpose(1, 0, 2)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda)
def convert_data(x: jnp.ndarray, num_steps):
x = jax.random.permutation(subkey, x, axis=1)
return reshape_minibatch(x, num_minibatches, num_steps)
shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2))
shuffled_storage, shuffled_switch, shuffled_advantages, shuffled_target_values = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch, advantages, target_values))
if learn_opponent:
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
else:
shuffled_mask = shuffled_storage.mains
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_switch,
),
shuffled_storage.actions,
shuffled_storage.logprobs,
shuffled_storage.probs,
shuffled_advantages,
shuffled_target_values,
shuffled_mask,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(8,),
)
params_queues = []
rollout_queues = []
stats_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
params_queues[-1].put(device_params)
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
stats_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
embedding_file: Optional[str] = None
"""the embedding file for card embeddings"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"""the number of actor threads to use"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
norm_adv: bool = False
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
spo_kld_max: Optional[float] = None
"""the maximum KLD for the SPO policy"""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: int = 512
"""the number of channels for the RNN in the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
"""the device ids that learner workers will use"""
distributed: bool = False
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
eval_checkpoint: Optional[str] = None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size: int = 0
local_minibatch_size: int = 0
world_size: int = 0
local_rank: int = 0
num_envs: int = 0
batch_size: int = 0
minibatch_size: int = 0
num_updates: int = 0
global_learner_decices: Optional[List[str]] = None
actor_devices: Optional[List[str]] = None
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
if not args.thread_affinity:
thread_affinity_offset = -1
if thread_affinity_offset >= 0:
print("Binding to thread offset", thread_affinity_offset)
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=num_threads,
thread_affinity_offset=thread_affinity_offset,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
return envs
class Transition(NamedTuple):
obs: list
dones: list
actions: list
logits: list
rewards: list
mains: list
next_dones: list
def create_agent(args, multi_step=False):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
lstm_channels=args.rnn_channels,
multi_step=multi_step,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
def rollout(
key: jax.random.PRNGKey,
args: Args,
rollout_queue,
params_queue,
eval_queue,
writer,
learner_devices,
device_thread_id,
):
eval_mode = 'self' if args.eval_checkpoint else 'bot'
if eval_mode != 'bot':
eval_params = params_queue.get()
envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_num_envs,
args.local_env_threads,
thread_affinity_offset=device_thread_id * args.local_env_threads,
)
envs = RecordEpisodeStatistics(envs)
eval_envs = make_env(
args,
args.seed + jax.process_index() + device_thread_id,
args.local_eval_episodes,
args.local_eval_episodes // 4, mode=eval_mode)
eval_envs = RecordEpisodeStatistics(eval_envs)
len_actor_device_ids = len(args.actor_device_ids)
n_actors = args.num_actor_threads * len_actor_device_ids
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def get_logits(
params: flax.core.FrozenDict, inputs, done):
rstate, logits = create_agent(args).apply(params, inputs)[:2]
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
return rstate, logits
@jax.jit
def get_action(
params: flax.core.FrozenDict, inputs):
batch_size = jax.tree.leaves(inputs)[0].shape[0]
done = jnp.zeros(batch_size, dtype=jnp.bool_)
rstate, logits = get_logits(params, inputs, done)
return rstate, logits.argmax(axis=1)
@jax.jit
def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
logits = jnp.where(main[:, None], logits1, logits2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, logits.argmax(axis=1)
@jax.jit
def sample_action(
params: flax.core.FrozenDict,
next_obs, rstate1, rstate2, main, done, key):
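# Pick the acting player's recurrent state (selected by `main`), run one agent step, and
# write the updated state back only for the player that just acted.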
next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
done = jnp.array(done)
main = jnp.array(main)
rstate = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, logits = get_logits(params, (rstate, next_obs), done)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
actor_policy_version = 0
next_obs, info = envs.reset()
next_to_play = info["to_play"]
next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
next_rstate1 = next_rstate2 = init_rnn_state(
args.local_num_envs, args.rnn_channels)
eval_rstate = init_rnn_state(
args.local_eval_episodes, args.rnn_channels)
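# Assign each env a fixed 'main' seat (player 0 or 1) whose transitions are treated as the
# main player's; half of the envs use each seat, shuffled once at startup.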
main_player = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(main_player)
start_step = 0
storage = []
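# prepare_data stacks the per-step Transitions into time-major arrays and splits each one
# along the env axis, producing one shard per learner device.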
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
for update in range(1, args.num_updates + 2):
if update == 10:
start_time = time.time()
warmup_step = global_step
update_time_start = time.time()
inference_time = 0
env_time = 0
params_queue_get_time_start = time.time()
if args.concurrency:
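# With concurrency enabled, skip the params fetch on the second update so the actor runs
# one iteration ahead of the learner; from update 3 on, the queue again provides fresh params.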
if update != 2:
params = params_queue.get()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
main = next_to_play == main_player
inference_time_start = time.time()
cached_next_obs, cached_next_done, cached_main, \
next_rstate1, next_rstate2, action, logits, key = sample_action(
params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
_start = time.time()
next_obs, next_reward, next_done, info = envs.step(cpu_action)
next_to_play = info["to_play"]
env_time += time.time() - _start
storage.append(
Transition(
obs=cached_next_obs,
dones=cached_next_done,
mains=cached_main,
actions=action,
logits=logits,
rewards=next_reward,
next_dones=next_done,
)
)
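# When an episode ends, walk backwards to the other player's last transition of that episode
# and mark it terminal with the negated final reward, so both players see a terminal step
# (unless the game ended before the opponent ever acted, e.g. an OTK).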
for idx, d in enumerate(next_done):
if not d:
continue
cur_main = main[idx]
for j in reversed(range(len(storage) - 1)):
t = storage[j]
if t.next_dones[idx]:
# For OTK where player may not switch
break
if t.mains[idx] != cur_main:
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
avg_win_rates.append(win)
rollout_time.append(time.time() - rollout_time_start)
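# All collect_length transitions are sent to the learner; the trailing
# (collect_length - num_steps) of them are kept so the next iteration resumes at start_step
# and only collects num_steps fresh steps.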
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
x = {
k: jax.device_put_sharded(v, devices=learner_devices)
for k, v in x.items()
}
else:
x = jax.device_put_sharded(x, devices=learner_devices)
sharded_storage.append(x)
sharded_storage = Transition(*sharded_storage)
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
)
rollout_queue.put(payload)
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_stat = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
metric_name = "eval_return"
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_stat = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
metric_name = "eval_win_rate"
if device_thread_id != 0:
eval_queue.put(eval_stat)
else:
eval_stats = []
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(eval_queue.get())
eval_stats = np.mean(eval_stats)
writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
args = tyro.cli(Args)
args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
assert (
args.local_num_envs % len(args.learner_device_ids) == 0
), "local_num_envs must be divisible by len(learner_device_ids)"
assert (
int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if args.distributed:
jax.distributed.initialize(
local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
)
print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
args.batch_size = args.local_batch_size * args.world_size
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
global_learner_decices = [
global_devices[d_id + process_index * len(local_devices)]
for process_index in range(args.world_size)
for d_id in args.learner_device_ids
]
print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices]
args.learner_devices = [str(item) for item in learner_devices]
pprint(args)
timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
learner_keys = jax.device_put_replicated(key, learner_devices)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration, which contains
# (args.num_minibatches * args.update_epochs) gradient updates
frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
return args.learning_rate * frac
rstate = init_rnn_state(1, args.rnn_channels)
agent = create_agent(args)
params = agent.init(agent_key, (rstate, sample_obs))
tx = optax.MultiSteps(
optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
)
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
agent_state = agent_state.replace(params=params)
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
# print(agent.tabulate(agent_key, sample_obs))
if args.eval_checkpoint:
with open(args.eval_checkpoint, "rb") as f:
eval_params = flax.serialization.from_bytes(params, f.read())
print(f"loaded eval checkpoint from {args.eval_checkpoint}")
else:
eval_params = None
@jax.jit
def get_logits_and_value(
params: flax.core.FrozenDict, inputs,
):
rstate, logits, value, valid = create_agent(
args, multi_step=True).apply(params, inputs)
return logits, value.squeeze(-1)
def ppo_loss(
params, rstate1, rstate2, obs, dones, next_dones,
switch, actions, logits, rewards, mask, next_value):
# inputs are flattened to shape (num_steps * local_num_envs // num_minibatches, ...)
num_envs = next_value.shape[0]
num_steps = dones.shape[0] // num_envs
mask = mask & (~dones)
n_valids = jnp.sum(mask)
real_dones = dones | next_dones
inputs = (rstate1, rstate2, obs, real_dones, switch)
new_logits, new_values = get_logits_and_value(params, inputs)
values, rewards, next_dones, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, num_envs)),
(jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
)
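# Advantages and value targets are recomputed inside the loss from the current network's
# (stop-gradient) values, so every minibatch and epoch uses fresh targets rather than values
# cached at rollout time.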
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, values, rewards, next_dones, switch,
args.gamma, args.gae_lambda)
advantages, target_values = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
ratio = distrax.importance_sampling_ratios(distrax.Categorical(
new_logits), distrax.Categorical(logits), actions)
logratio = jnp.log(ratio)
approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
if args.norm_adv:
advantages = masked_normalize(advantages, mask, eps=1e-8)
# Policy loss
if args.spo_kld_max is not None:
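# SPO-style policy loss (used instead of PPO clipping when spo_kld_max is set): the per-state
# KL between the behaviour and current policies is clipped at spo_kld_max, and the surrogate
# -advantages * ratio is scaled by d_ratio for positive advantages and by (2 - d_ratio) for
# negative ones.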
probs = jax.nn.softmax(logits)
new_probs = jax.nn.softmax(new_logits)
eps = 1e-8
kld = jnp.sum(
probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
d_ratio = kld_clip / (kld + eps)
d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
sign_a = jnp.sign(advantages)
result = (d_ratio + sign_a - 1) * sign_a
pg_loss = -advantages * ratio * result
else:
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = jnp.maximum(pg_loss1, pg_loss2)
pg_loss = jnp.sum(pg_loss * mask)
v_loss = 0.5 * ((new_values - target_values) ** 2)
v_loss = jnp.sum(v_loss * mask)
entropy_loss = distrax.Softmax(new_logits).entropy()
entropy_loss = jnp.sum(entropy_loss * mask)
pg_loss = pg_loss / n_valids
v_loss = v_loss / n_valids
entropy_loss = entropy_loss / n_valids
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_init_rstate1: List,
sharded_init_rstate2: List,
sharded_next_inputs: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
next_inputs, init_rstate1, init_rstate2 = [
jax.tree.map(lambda *x: jnp.concatenate(x), *x)
for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
]
next_main = jnp.concatenate(sharded_next_main)
# reorder storage of individual players
# main first, opponent second
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
next_value = create_agent(args).apply(
agent_state.params, next_inputs)[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
def convert_data(x: jnp.ndarray, num_steps):
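# Shuffle along the env axis only and fold time into the batch, so each minibatch keeps whole
# env trajectories in time order for the LSTM; the shared subkey yields the same env
# permutation for per-env data (num_steps == 1) and time-major data, keeping them aligned.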
x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
N = args.num_minibatches
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
x = jnp.reshape(x, (N, -1) + x.shape[1:])
return x
shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
shuffled_storage, shuffled_switch = jax.tree.map(
partial(convert_data, num_steps=num_steps), (storage, switch))
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params, *minibatch)
grads = jax.lax.pmean(grads, axis_name="local_devices")
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_minibatch,
agent_state,
(
shuffled_init_rstate1,
shuffled_init_rstate2,
shuffled_storage.obs,
shuffled_storage.dones,
shuffled_storage.next_dones,
shuffled_switch,
shuffled_storage.actions,
shuffled_storage.logits,
shuffled_storage.rewards,
shuffled_mask,
shuffled_next_value,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
(agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
update_epoch, (agent_state, key), (), length=args.update_epochs
)
loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
multi_device_update = jax.pmap(
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(7,),
)
params_queues = []
rollout_queues = []
eval_queues = queue.Queue()
dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda x, y, z: None
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
for thread_id in range(args.num_actor_threads):
params_queues.append(queue.Queue(maxsize=1))
rollout_queues.append(queue.Queue(maxsize=1))
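# If a frozen eval checkpoint is used, push its params first; rollout() pops them before the
# regular training params, which are put below once the thread has started.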
if eval_params:
params_queues[-1].put(
jax.device_put(eval_params, local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queues[-1],
params_queues[-1],
eval_queues,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
d_idx * args.num_actor_threads + thread_id,
),
).start()
params_queues[-1].put(device_params)
rollout_queue_get_time = deque(maxlen=10)
data_transfer_time = deque(maxlen=10)
learner_policy_version = 0
while True:
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_data_list = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
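# Wait for the host-to-device transfer to finish before handing the new params to the actor threads.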
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
raise ValueError(f"loss is {loss}")
# record rewards for plotting purposes
if learner_policy_version % args.log_frequency == 0:
writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
writer.add_scalar(
"stats/rollout_params_queue_get_time_diff",
np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
global_step,
)
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
M_steps = args.batch_size * learner_policy_version // (2**20)
model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(unreplicated_params)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file