Commit 9670ed68 authored by sbl1996@126.com

Refactor to main_player

parent 6b23ca2d
...@@ -180,14 +180,19 @@ if __name__ == "__main__":
         agent1 = optimize_for_inference(agent1)
         agent2 = optimize_for_inference(agent2)
 
-        def predict_fn(agent, obs):
+        def predict_fn(obs, main):
             obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
-            probs = get_probs(agent, obs)
+            if num_envs != 1:
+                probs1 = get_probs(agent1, obs)
+                probs2 = get_probs(agent2, obs)
+                probs = torch.where(main[:, None], probs1, probs2)
+            else:
+                if main[0]:
+                    probs = get_probs(agent1, obs)
+                else:
+                    probs = get_probs(agent2, obs)
             probs = probs.cpu().numpy()
             return probs
 
-        predict_fn1 = lambda obs: predict_fn(agent1, obs)
-        predict_fn2 = lambda obs: predict_fn(agent2, obs)
     else:
         import jax
         import jax.numpy as jnp
...@@ -214,21 +219,30 @@ if __name__ == "__main__":
             params2 = flax.serialization.from_bytes(params, f.read())
 
         @jax.jit
-        def get_probs(
-            params: flax.core.FrozenDict,
-            next_obs,
-        ):
-            logits = create_agent(args).apply(params, next_obs)[0]
+        def get_probs(params, obs):
+            agent = create_agent(args)
+            logits = agent.apply(params, obs)[0]
             return jax.nn.softmax(logits, axis=-1)
 
-        def predict_fn(params, obs):
-            probs = get_probs(params, obs)
-            return np.array(probs)
-
-        predict_fn1 = lambda obs: predict_fn(params1, obs)
-        predict_fn2 = lambda obs: predict_fn(params2, obs)
+        if args.num_envs != 1:
+            @jax.jit
+            def get_probs2(params1, params2, obs, main):
+                probs1 = get_probs(params1, obs)
+                probs2 = get_probs(params2, obs)
+                probs = jnp.where(main[:, None], probs1, probs2)
+                return probs
+
+            def predict_fn(obs, main):
+                probs = get_probs2(params1, params2, obs, main)
+                return np.array(probs)
+        else:
+            def predict_fn(obs, main):
+                if main[0]:
+                    probs = get_probs(params1, obs)
+                else:
+                    probs = get_probs(params2, obs)
+                return np.array(probs)
 
     obs, infos = envs.reset()
     next_to_play = infos['to_play']
...@@ -241,7 +255,7 @@ if __name__ == "__main__":
     start = time.time()
    start_step = step
 
-    player1 = np.concatenate([
+    main_player = np.concatenate([
         np.zeros(num_envs // 2, dtype=np.int64),
         np.ones(num_envs - num_envs // 2, dtype=np.int64)
     ])
...@@ -254,15 +268,9 @@ if __name__ == "__main__":
                 model_time = env_time = 0
 
         _start = time.time()
-        if args.num_envs != 1:
-            probs1 = predict_fn1(obs)
-            probs2 = predict_fn2(obs)
-            probs = np.where((next_to_play == player1)[:, None], probs1, probs2)
-        else:
-            if (next_to_play == player1).all():
-                probs = predict_fn1(obs)
-            else:
-                probs = predict_fn2(obs)
+        main = next_to_play == main_player
+        probs = predict_fn(obs, main)
         actions = probs.argmax(axis=1)
         model_time += time.time() - _start
...@@ -279,7 +287,7 @@ if __name__ == "__main__":
         for idx, d in enumerate(dones):
             if d:
                 win_reason = infos['win_reason'][idx]
-                pl = 1 if to_play[idx] == player1[idx] else -1
+                pl = 1 if to_play[idx] == main_player[idx] else -1
                 episode_length = infos['l'][idx]
                 episode_reward = infos['r'][idx] * pl
                 win = int(episode_reward > 0)
...@@ -292,7 +300,7 @@ if __name__ == "__main__":
                 # Only when num_envs=1, we switch the player here
                 if args.verbose:
-                    player1 = 1 - player1
+                    main_player = 1 - main_player
 
                 if len(episode_lengths) >= args.num_episodes:
                     break
......
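The torch branch above replaces the per-agent predict_fn1/predict_fn2 pair with a single predict_fn(obs, main) that selects each environment's action distribution with a boolean mask. A minimal NumPy sketch of that selection (toy shapes and made-up probabilities; torch.where and jnp.where broadcast the same way):

import numpy as np

num_envs, num_actions = 4, 3
main = np.array([True, True, False, False])      # envs where the main agent (agent1) is to act
probs1 = np.full((num_envs, num_actions), 0.1)   # stand-in output of agent1
probs2 = np.full((num_envs, num_actions), 0.9)   # stand-in output of agent2

# Same selection as torch.where(main[:, None], probs1, probs2) above:
probs = np.where(main[:, None], probs1, probs2)
print(probs[:, 0])   # [0.1 0.1 0.9 0.9] -> rows 0-1 from agent1, rows 2-3 from agent2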
import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
import ygoenv
import numpy as np
import tyro
import jax
import jax.numpy as jnp
import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import PPOLSTMAgent
@dataclass
class Args:
seed: int = 1
"""the random seed"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
record: bool = False
"""whether to record the game as YGOPro replays"""
num_episodes: int = 1024
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework: Optional[Literal["torch", "jax"]] = None
def create_agent(args):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
lstm_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
if __name__ == "__main__":
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
if not os.path.exists("replay"):
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
seed = args.seed
random.seed(seed)
np.random.seed(seed)
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=args.env_threads,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
player=-1,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read())
if args.checkpoint1 == args.checkpoint2:
params2 = params1
else:
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(params, rstate, obs, done):
agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1)
next_rstate = jax.tree_map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs
if args.num_envs != 1:
@jax.jit
def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
probs = jnp.where(main[:, None], probs1, probs2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, probs
def predict_fn(rstate1, rstate2, obs, main, done):
rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done)
return rstate1, rstate2, np.array(probs)
else:
def predict_fn(rstate1, rstate2, obs, main, done):
if main[0]:
rstate1, probs = get_probs(params1, rstate1, obs, done)
else:
rstate2, probs = get_probs(params2, rstate2, obs, done)
return rstate1, rstate2, np.array(probs)
obs, infos = envs.reset()
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
episode_rewards = []
episode_lengths = []
win_rates = []
win_reasons = []
step = 0
start = time.time()
start_step = step
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
model_time = env_time = 0
while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
start = time.time()
start_step = step
model_time = env_time = 0
_start = time.time()
main = next_to_play == main_player
rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones)
actions = probs.argmax(axis=1)
model_time += time.time() - _start
to_play = next_to_play
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
next_to_play = infos['to_play']
env_time += time.time() - _start
step += 1
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
# Only when num_envs=1, we switch the player here
if args.verbose:
main_player = 1 - main_player
if len(episode_lengths) >= args.num_episodes:
break
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
total_time = time.time() - start
total_steps = (step - start_step) * num_envs
print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
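The LSTM variant of the script keeps one recurrent state per agent, and get_probs2 only advances the state of the agent that actually acted in each environment. A small NumPy sketch of that bookkeeping (toy shapes and values; a real rstate is an (h, c) tuple handled with jax.tree.map):

import numpy as np

num_envs, channels = 4, 2
main = np.array([True, False, True, False])        # envs where the main agent moves this step
rstate1 = np.zeros((num_envs, channels))           # carried state of agent1 (main)
rstate2 = np.zeros((num_envs, channels))           # carried state of agent2 (opponent)
next_rstate1 = np.ones((num_envs, channels))       # states produced this step (stand-ins)
next_rstate2 = np.full((num_envs, channels), 2.0)

# Mirrors the jnp.where calls in get_probs2: each agent's state advances only
# in the environments where that agent was the one to act.
rstate1 = np.where(main[:, None], next_rstate1, rstate1)
rstate2 = np.where(main[:, None], rstate2, next_rstate2)
print(rstate1[:, 0], rstate2[:, 0])   # [1. 0. 1. 0.] [0. 2. 0. 2.]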
...@@ -312,7 +312,7 @@ def rollout(
         params_queue_get_time.append(time.time() - params_queue_get_time_start)
         rollout_time_start = time.time()
 
-        initial_rstate1, initial_rstate2 = jax.tree.map(
+        init_rstate1, init_rstate2 = jax.tree.map(
             lambda x: x.copy(), (next_rstate1, next_rstate2))
         for _ in range(start_step, args.collect_length):
             global_step += args.local_num_envs * n_actors * args.world_size
...@@ -385,15 +385,10 @@ def rollout(
         next_main = main_player == next_to_play
         next_rstate = jax.tree.map(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
-        # initial_rstate1: main, initial_rstate2: opponent
-        # init rstate1: == next_main, init rstate2: != next_main
-        init_rstate1 = jax.tree.map(
-            lambda x, y: jnp.where(next_main[:, None], x, y), initial_rstate1, initial_rstate2)
-        init_rstate2 = jax.tree.map(
-            lambda x, y: jnp.where(next_main[:, None], y, x), initial_rstate1, initial_rstate2)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
             (next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
+        learn_opponent = False
         payload = (
             global_step,
             actor_policy_version,
...@@ -402,6 +397,7 @@ def rollout(
             *sharded_data,
             np.mean(params_queue_get_time),
             device_thread_id,
+            learn_opponent,
         )
         rollout_queue.put(payload)
...@@ -565,9 +561,10 @@ if __name__ == "__main__":
         return logprob, probs, entropy, value.squeeze(), valid
 
     def ppo_loss(
-        params, inputs, actions, logprobs, probs, advantages, target_values):
+        params, inputs, actions, logprobs, probs, advantages, target_values, mask):
         newlogprob, newprobs, entropy, newvalue, valid = \
             get_logprob_entropy_value(params, inputs, actions)
+        valid = valid & mask
         logratio = newlogprob - logprobs
         ratio = jnp.exp(logratio)
         approx_kl = ((ratio - 1) - logratio).mean()
...@@ -600,7 +597,6 @@ if __name__ == "__main__":
         loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
         return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
 
-    @jax.jit
     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
...@@ -611,11 +607,12 @@ if __name__ == "__main__":
         sharded_next_done: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
+        learn_opponent: bool = False,
     ):
-        def reshape_minibatch(x, num_minibatches, multi_step=False):
+        def reshape_minibatch(x, num_minibatches, num_steps=1):
             N = num_minibatches
-            if multi_step:
-                x = jnp.reshape(x, (args.num_steps, N, -1) + x.shape[2:])
+            if num_steps > 1:
+                x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
                 x = x.transpose(1, 0, *range(2, x.ndim))
                 x = x.reshape(N, -1, *x.shape[3:])
             else:
...@@ -632,21 +629,37 @@ if __name__ == "__main__":
         ]
 
         # reorder storage of individual players
+        # main first, opponent second
         num_steps, num_envs = storage.rewards.shape
         T = jnp.arange(num_steps, dtype=jnp.int32)
         B = jnp.arange(num_envs, dtype=jnp.int32)
-        mains = (storage.mains == next_main).astype(jnp.int32)
-        indices = jnp.argsort(T[:, None] + mains * num_steps, axis=0)
-        switch = T[:, None] == (num_steps - 1 - jnp.sum(mains, axis=0))
+        mains = storage.mains.astype(jnp.int32)
+        indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
+        switch_steps = jnp.sum(mains, axis=0)
+        switch = T[:, None] == (switch_steps[None, :] - 1)
+        if not learn_opponent:
+            num_steps = int(num_steps * 0.75)
+            indices = indices[:num_steps + 1]
+            switch = switch[:num_steps]
         storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
+        if not learn_opponent:
+            next_obs = jax.tree.map(lambda x: x[num_steps], storage.obs)
+            next_done = storage.dones[num_steps]
+            next_main = storage.mains[num_steps]
+            storage = jax.tree.map(lambda x: x[:num_steps], storage)
 
         # split minibatches for recompute values
-        n_mbs = args.num_minibatches // 4
+        num_minibatches = args.num_minibatches
+        if not learn_opponent:
+            num_minibatches = num_minibatches // 2
+        n_mbs = num_minibatches // 4
         split_init_rstate = jax.tree.map(
             partial(reshape_minibatch, num_minibatches=n_mbs),
             (init_rstate1, init_rstate2))
         split_inputs = jax.tree.map(
-            partial(reshape_minibatch, num_minibatches=n_mbs, multi_step=True),
+            partial(reshape_minibatch, num_minibatches=n_mbs, num_steps=num_steps),
            (storage.obs, storage.dones, switch))
         split_inputs = split_init_rstate + split_inputs
...@@ -663,27 +676,32 @@ if __name__ == "__main__":
         _, values = jax.lax.scan(
             get_value_minibatch, agent_state, split_inputs)
-        values = values.reshape((n_mbs, args.num_steps, -1)).transpose(1, 0, 2)
+        values = values.reshape((n_mbs, num_steps, -1)).transpose(1, 0, 2)
         values = values.reshape(storage.rewards.shape)
 
         next_value = create_agent(args).apply(
             agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
+        # TODO: check if this is correct
+        sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
+        next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
         compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
         advantages, target_values = compute_gae_fn(
             next_value, next_done, values, storage.rewards, storage.dones, switch,
             args.gamma, args.gae_lambda)
-        advantages = advantages[:args.num_steps]
-        target_values = target_values[:args.num_steps]
 
-        def convert_data(x: jnp.ndarray, multi_step):
+        def convert_data(x: jnp.ndarray, num_steps):
             x = jax.random.permutation(subkey, x, axis=1)
-            return reshape_minibatch(x, args.num_minibatches, multi_step)
+            return reshape_minibatch(x, num_minibatches, num_steps)
 
         shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
-            partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
+            partial(convert_data, num_steps=1), (init_rstate1, init_rstate2))
         shuffled_storage, shuffled_switch, shuffled_advantages, shuffled_target_values = jax.tree.map(
-            partial(convert_data, multi_step=True), (storage, switch, advantages, target_values))
+            partial(convert_data, num_steps=num_steps), (storage, switch, advantages, target_values))
+        if learn_opponent:
+            shuffled_mask = jnp.ones_like(shuffled_storage.mains)
+        else:
+            shuffled_mask = shuffled_storage.mains
 
         def update_minibatch(agent_state, minibatch):
             (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
...@@ -708,6 +726,7 @@ if __name__ == "__main__":
                     shuffled_storage.probs,
                     shuffled_advantages,
                     shuffled_target_values,
+                    shuffled_mask,
                 ),
             )
             return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
...@@ -726,6 +745,7 @@ if __name__ == "__main__":
        single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
+        static_broadcasted_argnums=(9,),
     )
 
     params_queues = []
...@@ -771,6 +791,7 @@ if __name__ == "__main__":
                     *sharded_data,
                     avg_params_queue_get_time,
                     device_thread_id,
+                    learn_opponent,
                 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                 sharded_data_list.append(sharded_data)
             rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
...@@ -779,6 +800,7 @@ if __name__ == "__main__":
             agent_state,
             *list(zip(*sharded_data_list)),
             learner_keys,
+            learn_opponent,
         )
         unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
         for d_idx, d_id in enumerate(args.actor_device_ids):
......
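The reordering above sorts each environment's transitions so the main player's steps come first in time order, followed by the opponent's, with switch marking the main player's last step. A small NumPy check of the argsort trick on a toy action pattern (jnp.argsort is stable by default, so kind="stable" stands in for it here):

import numpy as np

num_steps, num_envs = 6, 1
T = np.arange(num_steps)[:, None]
mains = np.array([[1], [0], [1], [0], [1], [0]])   # 1 = main player acted, 0 = opponent

# Main-player steps get keys T - num_steps (all negative), opponent steps keep T,
# so a stable sort lists main steps first, then opponent steps, each in time order.
indices = np.argsort(T - mains * num_steps, axis=0, kind="stable")
print(indices[:, 0])                                # [0 2 4 1 3 5]

switch_steps = mains.sum(axis=0)                    # 3 main-player steps in this env
switch = T == (switch_steps[None, :] - 1)           # row 2 of the reordered storage,
print(np.flatnonzero(switch[:, 0]))                 # i.e. the main player's last step -> [2]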
...@@ -22,7 +22,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
 from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
-from ygoai.rl.ppo import bootstrap_value_selfplay
+from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
 from ygoai.rl.eval import evaluate
...@@ -242,6 +242,7 @@ def main():
         embedding_shape = None
     L = args.num_layers
     agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
+    torch.manual_seed(args.seed)
     agent.eval()
 
     if args.checkpoint:
...@@ -271,7 +272,6 @@ def main():
             logits, value, valid = agent(next_obs)
         return logits, value
 
-    from ygoai.rl.ppo import train_step
     if args.compile:
         # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
         # predict_step = torch.compile(predict_step, mode=args.compile)
...@@ -284,10 +284,11 @@ def main():
         else:
             traced_model_t = traced_model
 
-        train_step = torch.compile(train_step, mode=args.compile)
+        train_step = torch.compile(train_step_, mode=args.compile)
     else:
         traced_model = agent
         traced_model_t = agent_t
+        train_step = train_step_
 
     # ALGO Logic: Storage setup
     obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
...@@ -310,12 +311,12 @@ def main():
     next_to_play_ = info["to_play"]
     next_to_play = to_tensor(next_to_play_, device)
     next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
-    ai_player1_ = np.concatenate([
+    main_player_ = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
-    np.random.shuffle(ai_player1_)
-    ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
+    np.random.shuffle(main_player_)
+    main_player = to_tensor(main_player_, device, dtype=next_to_play.dtype)
     step = 0
 
     for iteration in range(args.num_iterations):
...@@ -334,7 +335,7 @@ def main():
             for key in obs:
                 obs[key][step] = next_obs[key]
             dones[step] = next_done
-            learn = next_to_play == ai_player1
+            learn = next_to_play == main_player
             learns[step] = learn
 
             _start = time.time()
...@@ -369,7 +370,7 @@ def main():
             for idx, d in enumerate(next_done_):
                 if d:
-                    pl = 1 if to_play[idx] == ai_player1_[idx] else -1
+                    pl = 1 if to_play[idx] == main_player_[idx] else -1
                     episode_length = info['l'][idx]
                     episode_reward = info['r'][idx] * pl
                     win = 1 if episode_reward > 0 else 0
...@@ -395,14 +396,13 @@ def main():
             _start = time.time()
             # bootstrap value if not done
-            with torch.no_grad():
-                value = predict_step(traced_model, next_obs)[1].reshape(-1)
-                nextvalues1 = torch.where(next_to_play == ai_player1, value, -value)
-                if args.fix_target:
-                    value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
-                    nextvalues2 = torch.where(next_to_play != ai_player1, value_t, -value_t)
-                else:
-                    nextvalues2 = -nextvalues1
+            value = predict_step(traced_model, next_obs)[1].reshape(-1)
+            nextvalues1 = torch.where(next_to_play == main_player, value, -value)
+            if args.fix_target:
+                value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
+                nextvalues2 = torch.where(next_to_play != main_player, value_t, -value_t)
+            else:
+                nextvalues2 = -nextvalues1
 
             if step > 0 and iteration != 0:
                 # recalculate the values for the first few steps
......
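The bootstrap block above uses the zero-sum convention that the critic's value is taken from the perspective of the player to move, so the main player's bootstrap target flips sign whenever the opponent acts next. A toy NumPy version of that selection with made-up values:

import numpy as np

value = np.array([0.8, -0.3, 0.5])     # critic value of next_obs, for whoever moves next
next_to_play = np.array([0, 1, 1])
main_player = np.array([0, 0, 1])

# Same as torch.where(next_to_play == main_player, value, -value):
nextvalues1 = np.where(next_to_play == main_player, value, -value)
nextvalues2 = -nextvalues1              # opponent's perspective when not using a fixed target
print(nextvalues1)                      # [ 0.8  0.3  0.5]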
...@@ -21,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
 from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
-from ygoai.rl.ppo import bootstrap_value_selfplay
+from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
 from ygoai.rl.eval import evaluate
...@@ -261,6 +261,7 @@ def main():
         embedding_shape = None
     L = args.num_layers
     agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
+    torch.manual_seed(args.seed)
     agent.eval()
 
     if args.checkpoint:
...@@ -289,7 +290,6 @@ def main():
     history = []
 
-    from ygoai.rl.ppo import train_step
     if args.compile:
         # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
         # predict_step = torch.compile(predict_step, mode=args.compile)
...@@ -302,7 +302,10 @@ def main():
             traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
             history.append(traced_model_t)
 
-        train_step = torch.compile(train_step, mode=args.compile)
+        train_step = torch.compile(train_step_, mode=args.compile)
+    else:
+        train_step = train_step_
 
     def sample_target(history):
         ts = []
...@@ -331,12 +334,12 @@ def main():
     warmup_steps = 0
     start_time = time.time()
     next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
-    ai_player1_ = np.concatenate([
+    main_player_ = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
-    np.random.shuffle(ai_player1_)
-    ai_player1 = to_tensor(ai_player1_, device)
+    np.random.shuffle(main_player_)
+    main_player = to_tensor(main_player_, device)
     next_value1 = next_value2 = 0
     step = 0
     ts = []
...@@ -374,7 +377,7 @@ def main():
             for key in obs:
                 obs[key][step] = next_obs[key]
             dones[step] = next_done
-            learn = next_to_play == ai_player1
+            learn = next_to_play == main_player
             learns[step] = learn
 
             _start = time.time()
...@@ -410,7 +413,7 @@ def main():
             for idx, d in enumerate(next_done_):
                 if d:
-                    pl = 1 if to_play[idx] == ai_player1_[idx] else -1
+                    pl = 1 if to_play[idx] == main_player_[idx] else -1
                     episode_length = info['l'][idx]
                     episode_reward = info['r'][idx] * pl
                     win = 1 if episode_reward > 0 else 0
...@@ -442,9 +445,9 @@ def main():
                 value = predict_step(traced_model, next_obs)[1].reshape(-1)
                 if not selfplay:
                     value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
-                    value = torch.where(next_to_play == ai_player1, value, value_t)
-                nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
-                nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
+                    value = torch.where(next_to_play == main_player, value, value_t)
+                nextvalues1 = torch.where(next_to_play == main_player, value, next_value1)
+                nextvalues2 = torch.where(next_to_play != main_player, value, next_value2)
 
             if step > 0 and iteration != 0:
                 # recalculate the values for the first few steps
......
...@@ -35,18 +35,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
         self.episode_lengths *= 1 - dones
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
-        # env_id = infos["env_id"]
-        # self.env_id = env_id
-        # self.episode_returns[env_id] += infos["reward"]
-        # self.returned_episode_returns[env_id] = np.where(
-        #     infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
-        # )
-        # self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
-        # self.episode_lengths[env_id] += 1
-        # self.returned_episode_lengths[env_id] = np.where(
-        #     infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
-        # )
-        # self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
 
         return (
             observations,
......
...@@ -121,53 +121,6 @@ def train_step_t(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages, b
     return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
 
-# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
-#     logits, newvalue, valid = agent(mb_obs)
-#     logits = logits - logits.logsumexp(dim=-1, keepdim=True)
-#     newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
-#     entropy = entropy_from_logits(logits)
-#     valid = torch.logical_and(valid, mb_learns)
-#     logratio = newlogprob - mb_logprobs
-#     ratio = logratio.exp()
-#
-#     with torch.no_grad():
-#         # calculate approx_kl http://joschu.net/blog/kl-approx.html
-#         old_approx_kl = (-logratio).mean()
-#         approx_kl = ((ratio - 1) - logratio).mean()
-#         clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
-#
-#     if args.norm_adv:
-#         mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
-#
-#     # Policy loss
-#     pg_loss1 = -mb_advantages * ratio
-#     pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
-#     pg_loss = torch.max(pg_loss1, pg_loss2)
-#     pg_loss = masked_mean(pg_loss, valid)
-#
-#     # Value loss
-#     newvalue = newvalue.view(-1)
-#     if args.clip_vloss:
-#         v_loss_unclipped = (newvalue - mb_returns) ** 2
-#         v_clipped = mb_values + torch.clamp(
-#             newvalue - mb_values,
-#             -args.clip_coef,
-#             args.clip_coef,
-#         )
-#         v_loss_clipped = (v_clipped - mb_returns) ** 2
-#         v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
-#         v_loss = 0.5 * v_loss_max
-#     else:
-#         v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
-#     v_loss = masked_mean(v_loss, valid)
-#
-#     entropy_loss = masked_mean(entropy, valid)
-#     loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
-#     loss.backward()
-#     optimizer.step()
-#     return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
 
 def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
     num_steps = rewards.size(0)
     advantages = torch.zeros_like(rewards)
......