Commit 9670ed68 authored by sbl1996@126.com's avatar sbl1996@126.com

Refactor to main_player
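Rename player1 / ai_player1 to main_player in the eval and training scripts, replace the separate predict_fn1/predict_fn2 paths with a single predict_fn(obs, main) that selects between the two agents via a boolean main mask, and thread a learn_opponent flag from the rollout threads into single_device_update so opponent transitions can be masked out of the PPO loss.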

parent 6b23ca2d
......@@ -180,14 +180,19 @@ if __name__ == "__main__":
agent1 = optimize_for_inference(agent1)
agent2 = optimize_for_inference(agent2)
def predict_fn(agent, obs):
def predict_fn(obs, main):
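# main: per-env boolean mask, True where the main agent (agent1, loaded from checkpoint1) is the player to act.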
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
probs = get_probs(agent, obs)
if num_envs != 1:
probs1 = get_probs(agent1, obs)
probs2 = get_probs(agent2, obs)
probs = torch.where(main[:, None], probs1, probs2)
else:
if main[0]:
probs = get_probs(agent1, obs)
else:
probs = get_probs(agent2, obs)
probs = probs.cpu().numpy()
return probs
predict_fn1 = lambda obs: predict_fn(agent1, obs)
predict_fn2 = lambda obs: predict_fn(agent2, obs)
else:
import jax
import jax.numpy as jnp
......@@ -214,21 +219,30 @@ if __name__ == "__main__":
params2 = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(
params: flax.core.FrozenDict,
next_obs,
):
logits = create_agent(args).apply(params, next_obs)[0]
def get_probs(params, obs):
agent = create_agent(args)
logits = agent.apply(params, obs)[0]
return jax.nn.softmax(logits, axis=-1)
def predict_fn(params, obs):
probs = get_probs(params, obs)
return np.array(probs)
predict_fn1 = lambda obs: predict_fn(params1, obs)
predict_fn2 = lambda obs: predict_fn(params2, obs)
if args.num_envs != 1:
@jax.jit
def get_probs2(params1, params2, obs, main):
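# Evaluate both policies on the whole batch and, per env, keep the probabilities of whichever agent is to act: params1 where main is True, params2 otherwise.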
probs1 = get_probs(params1, obs)
probs2 = get_probs(params2, obs)
probs = jnp.where(main[:, None], probs1, probs2)
return probs
def predict_fn(obs, main):
probs = get_probs2(params1, params2, obs, main)
return np.array(probs)
else:
def predict_fn(obs, main):
if main[0]:
probs = get_probs(params1, obs)
else:
probs = get_probs(params2, obs)
return np.array(probs)
obs, infos = envs.reset()
next_to_play = infos['to_play']
......@@ -241,7 +255,7 @@ if __name__ == "__main__":
start = time.time()
start_step = step
player1 = np.concatenate([
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
......@@ -254,15 +268,9 @@ if __name__ == "__main__":
model_time = env_time = 0
_start = time.time()
if args.num_envs != 1:
probs1 = predict_fn1(obs)
probs2 = predict_fn2(obs)
probs = np.where((next_to_play == player1)[:, None], probs1, probs2)
else:
if (next_to_play == player1).all():
probs = predict_fn1(obs)
else:
probs = predict_fn2(obs)
main = next_to_play == main_player
probs = predict_fn(obs, main)
actions = probs.argmax(axis=1)
model_time += time.time() - _start
......@@ -279,7 +287,7 @@ if __name__ == "__main__":
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
pl = 1 if to_play[idx] == player1[idx] else -1
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl
win = int(episode_reward > 0)
......@@ -292,7 +300,7 @@ if __name__ == "__main__":
# Only when num_envs == 1 do we swap the main player here
if args.verbose:
player1 = 1 - player1
main_player = 1 - main_player
if len(episode_lengths) >= args.num_episodes:
break
......
import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
import ygoenv
import numpy as np
import tyro
import jax
import jax.numpy as jnp
import flax
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import PPOLSTMAgent
@dataclass
class Args:
seed: int = 1
"""the random seed"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
record: bool = False
"""whether to record the game as YGOPro replays"""
num_episodes: int = 1024
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
rnn_channels: Optional[int] = 512
"""the number of rnn channels for the agent"""
checkpoint1: str = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
checkpoint2: str = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
# Jax specific
xla_device: Optional[str] = None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
compile: bool = False
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework: Optional[Literal["torch", "jax"]] = None
def create_agent(args):
return PPOLSTMAgent(
channels=args.num_channels,
num_layers=args.num_layers,
lstm_channels=args.rnn_channels,
embedding_shape=args.num_embeddings,
)
def init_rnn_state(num_envs, rnn_channels):
return (
np.zeros((num_envs, rnn_channels)),
np.zeros((num_envs, rnn_channels)),
)
if __name__ == "__main__":
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
if not os.path.exists("replay"):
os.makedirs("replay")
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
seed = args.seed
random.seed(seed)
np.random.seed(seed)
if args.xla_device is not None:
os.environ.setdefault("JAX_PLATFORMS", args.xla_device)
num_envs = args.num_envs
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=args.env_threads,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
player=-1,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
async_reset=False,
verbose=args.verbose,
record=args.record,
)
obs_space = envs.observation_space
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
agent = create_agent(args)
key = jax.random.PRNGKey(args.seed)
key, agent_key = jax.random.split(key, 2)
sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
rstate = init_rnn_state(1, args.rnn_channels)
params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
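# The freshly initialized params only provide the tree structure that flax.serialization.from_bytes needs as a target; the actual weights come from the checkpoint files below.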
with open(args.checkpoint1, "rb") as f:
params1 = flax.serialization.from_bytes(params, f.read())
if args.checkpoint1 == args.checkpoint2:
params2 = params1
else:
with open(args.checkpoint2, "rb") as f:
params2 = flax.serialization.from_bytes(params, f.read())
@jax.jit
def get_probs(params, rstate, obs, done):
agent = create_agent(args)
next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
probs = jax.nn.softmax(logits, axis=-1)
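# Zero the recurrent state of envs whose episode just ended so the next episode starts from a fresh state.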
next_rstate = jax.tree_map(
lambda x: jnp.where(done[:, None], 0, x), next_rstate)
return next_rstate, probs
if args.num_envs != 1:
@jax.jit
def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
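# Run both agents on the batch; per env, keep the acting agent's probabilities and advance only that agent's recurrent state, leaving the waiting agent's state untouched.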
next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
probs = jnp.where(main[:, None], probs1, probs2)
rstate1 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
rstate2 = jax.tree.map(
lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
return rstate1, rstate2, probs
def predict_fn(rstate1, rstate2, obs, main, done):
rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done)
return rstate1, rstate2, np.array(probs)
else:
def predict_fn(rstate1, rstate2, obs, main, done):
if main[0]:
rstate1, probs = get_probs(params1, rstate1, obs, done)
else:
rstate2, probs = get_probs(params2, rstate2, obs, done)
return rstate1, rstate2, np.array(probs)
obs, infos = envs.reset()
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
episode_rewards = []
episode_lengths = []
win_rates = []
win_reasons = []
step = 0
start = time.time()
start_step = step
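# The main agent plays as the first player in half of the envs and as the second player in the rest, so both seats are covered during evaluation.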
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs - num_envs // 2, dtype=np.int64)
])
rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
model_time = env_time = 0
while True:
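# Once roughly 10% of the episodes have finished, restart the timers so the reported SPS excludes warm-up (e.g. JIT compilation of the first batches).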
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
start = time.time()
start_step = step
model_time = env_time = 0
_start = time.time()
main = next_to_play == main_player
rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones)
actions = probs.argmax(axis=1)
model_time += time.time() - _start
to_play = next_to_play
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
next_to_play = infos['to_play']
env_time += time.time() - _start
step += 1
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
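# infos['r'] is interpreted as the reward of the player who just acted (to_play); negate it when that player is not the main player so the logged reward and win rate are the main agent's.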
pl = 1 if to_play[idx] == main_player[idx] else -1
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl
win = int(episode_reward > 0)
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
# Only when num_envs == 1 do we swap the main player here
if args.verbose:
main_player = 1 - main_player
if len(episode_lengths) >= args.num_episodes:
break
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
total_time = time.time() - start
total_steps = (step - start_step) * num_envs
print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
......@@ -312,7 +312,7 @@ def rollout(
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
initial_rstate1, initial_rstate2 = jax.tree.map(
init_rstate1, init_rstate2 = jax.tree.map(
lambda x: x.copy(), (next_rstate1, next_rstate2))
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
......@@ -385,15 +385,10 @@ def rollout(
next_main = main_player == next_to_play
next_rstate = jax.tree.map(
lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
# initial_rstate1: main, initial_rstate2: opponent
# init rstate1: == next_main, init rstate2: != next_main
init_rstate1 = jax.tree.map(
lambda x, y: jnp.where(next_main[:, None], x, y), initial_rstate1, initial_rstate2)
init_rstate2 = jax.tree.map(
lambda x, y: jnp.where(next_main[:, None], y, x), initial_rstate1, initial_rstate2)
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
learn_opponent = False
payload = (
global_step,
actor_policy_version,
......@@ -402,6 +397,7 @@ def rollout(
*sharded_data,
np.mean(params_queue_get_time),
device_thread_id,
learn_opponent,
)
rollout_queue.put(payload)
......@@ -565,9 +561,10 @@ if __name__ == "__main__":
return logprob, probs, entropy, value.squeeze(), valid
def ppo_loss(
params, inputs, actions, logprobs, probs, advantages, target_values):
params, inputs, actions, logprobs, probs, advantages, target_values, mask):
newlogprob, newprobs, entropy, newvalue, valid = \
get_logprob_entropy_value(params, inputs, actions)
valid = valid & mask
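# mask keeps only the main player's transitions in the loss when the opponent is not being learned (it is all ones otherwise).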
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
......@@ -600,7 +597,6 @@ if __name__ == "__main__":
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@jax.jit
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
......@@ -611,11 +607,12 @@ if __name__ == "__main__":
sharded_next_done: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
def reshape_minibatch(x, num_minibatches, multi_step=False):
def reshape_minibatch(x, num_minibatches, num_steps=1):
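# Split the env axis into num_minibatches groups; when the data has a time axis (num_steps > 1), keep all of an env's steps inside the same minibatch so recurrent states can be recomputed over full sequences.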
N = num_minibatches
if multi_step:
x = jnp.reshape(x, (args.num_steps, N, -1) + x.shape[2:])
if num_steps > 1:
x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
x = x.transpose(1, 0, *range(2, x.ndim))
x = x.reshape(N, -1, *x.shape[3:])
else:
......@@ -632,21 +629,37 @@ if __name__ == "__main__":
]
# reorder storage of individual players
# main first, opponent second
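# Per env, argsort the time indices so the main player's steps come first (in their original order) followed by the opponent's steps; switch marks the last main-player step of each env, where control passes to the opponent.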
num_steps, num_envs = storage.rewards.shape
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)
mains = (storage.mains == next_main).astype(jnp.int32)
indices = jnp.argsort(T[:, None] + mains * num_steps, axis=0)
switch = T[:, None] == (num_steps - 1 - jnp.sum(mains, axis=0))
mains = storage.mains.astype(jnp.int32)
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
switch_steps = jnp.sum(mains, axis=0)
switch = T[:, None] == (switch_steps[None, :] - 1)
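# When the opponent is not being learned, keep only the first 75% of the reordered steps (mostly main-player transitions); one extra index is kept so the step after the cut can serve as next_obs/next_done/next_main for bootstrapping.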
if not learn_opponent:
num_steps = int(num_steps * 0.75)
indices = indices[:num_steps + 1]
switch = switch[:num_steps]
storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
if not learn_opponent:
next_obs = jax.tree.map(lambda x: x[num_steps], storage.obs)
next_done = storage.dones[num_steps]
next_main = storage.mains[num_steps]
storage = jax.tree.map(lambda x: x[:num_steps], storage)
# split minibatches for recompute values
n_mbs = args.num_minibatches // 4
num_minibatches = args.num_minibatches
if not learn_opponent:
num_minibatches = num_minibatches // 2
n_mbs = num_minibatches // 4
split_init_rstate = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs),
(init_rstate1, init_rstate2))
split_inputs = jax.tree.map(
partial(reshape_minibatch, num_minibatches=n_mbs, multi_step=True),
partial(reshape_minibatch, num_minibatches=n_mbs, num_steps=num_steps),
(storage.obs, storage.dones, switch))
split_inputs = split_init_rstate + split_inputs
......@@ -663,27 +676,32 @@ if __name__ == "__main__":
_, values = jax.lax.scan(
get_value_minibatch, agent_state, split_inputs)
values = values.reshape((n_mbs, args.num_steps, -1)).transpose(1, 0, 2)
values = values.reshape((n_mbs, num_steps, -1)).transpose(1, 0, 2)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(
agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
# TODO: check if this is correct
sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
next_value, next_done, values, storage.rewards, storage.dones, switch,
args.gamma, args.gae_lambda)
advantages = advantages[:args.num_steps]
target_values = target_values[:args.num_steps]
def convert_data(x: jnp.ndarray, multi_step):
def convert_data(x: jnp.ndarray, num_steps):
x = jax.random.permutation(subkey, x, axis=1)
return reshape_minibatch(x, args.num_minibatches, multi_step)
return reshape_minibatch(x, num_minibatches, num_steps)
shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
partial(convert_data, num_steps=1), (init_rstate1, init_rstate2))
shuffled_storage, shuffled_switch, shuffled_advantages, shuffled_target_values = jax.tree.map(
partial(convert_data, multi_step=True), (storage, switch, advantages, target_values))
partial(convert_data, num_steps=num_steps), (storage, switch, advantages, target_values))
if learn_opponent:
shuffled_mask = jnp.ones_like(shuffled_storage.mains)
else:
shuffled_mask = shuffled_storage.mains
def update_minibatch(agent_state, minibatch):
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
......@@ -708,6 +726,7 @@ if __name__ == "__main__":
shuffled_storage.probs,
shuffled_advantages,
shuffled_target_values,
shuffled_mask,
),
)
return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
......@@ -726,6 +745,7 @@ if __name__ == "__main__":
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
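# learn_opponent is a plain Python bool; marking it static lets pmap specialize on its value (re-tracing when it changes).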
static_broadcasted_argnums=(9,),
)
params_queues = []
......@@ -771,6 +791,7 @@ if __name__ == "__main__":
*sharded_data,
avg_params_queue_get_time,
device_thread_id,
learn_opponent,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
......@@ -779,6 +800,7 @@ if __name__ == "__main__":
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
......
......@@ -22,7 +22,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
from ygoai.rl.eval import evaluate
......@@ -242,6 +242,7 @@ def main():
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
torch.manual_seed(args.seed)
agent.eval()
if args.checkpoint:
......@@ -271,7 +272,6 @@ def main():
logits, value, valid = agent(next_obs)
return logits, value
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice causes a segfault at startup, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
......@@ -284,10 +284,11 @@ def main():
else:
traced_model_t = traced_model
train_step = torch.compile(train_step, mode=args.compile)
train_step = torch.compile(train_step_, mode=args.compile)
else:
traced_model = agent
traced_model_t = agent_t
train_step = train_step_
# ALGO Logic: Storage setup
obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)
......@@ -310,12 +311,12 @@ def main():
next_to_play_ = info["to_play"]
next_to_play = to_tensor(next_to_play_, device)
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player1_ = np.concatenate([
main_player_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
np.random.shuffle(main_player_)
main_player = to_tensor(main_player_, device, dtype=next_to_play.dtype)
step = 0
for iteration in range(args.num_iterations):
......@@ -334,7 +335,7 @@ def main():
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
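# learn marks the envs where the trained ("main") agent is to act this step; these flags are stored and later restrict the PPO update to the main player's transitions.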
learn = next_to_play == ai_player1
learn = next_to_play == main_player
learns[step] = learn
_start = time.time()
......@@ -369,7 +370,7 @@ def main():
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
pl = 1 if to_play[idx] == main_player_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
......@@ -395,14 +396,13 @@ def main():
_start = time.time()
# bootstrap value if not done
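# The value head predicts from the acting player's perspective; in this zero-sum self-play setup the other player's bootstrap value is its negation, unless fix_target is set, in which case it comes from the separate target model traced_model_t.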
with torch.no_grad():
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == ai_player1, value, -value)
if args.fix_target:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
nextvalues2 = torch.where(next_to_play != ai_player1, value_t, -value_t)
else:
nextvalues2 = -nextvalues1
value = predict_step(traced_model, next_obs)[1].reshape(-1)
nextvalues1 = torch.where(next_to_play == main_player, value, -value)
if args.fix_target:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
nextvalues2 = torch.where(next_to_play != main_player, value_t, -value_t)
else:
nextvalues2 = -nextvalues1
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
......
......@@ -21,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
from ygoai.rl.buffer import create_obs
from ygoai.rl.ppo import bootstrap_value_selfplay
from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
from ygoai.rl.eval import evaluate
......@@ -261,6 +261,7 @@ def main():
embedding_shape = None
L = args.num_layers
agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
torch.manual_seed(args.seed)
agent.eval()
if args.checkpoint:
......@@ -289,7 +290,6 @@ def main():
history = []
from ygoai.rl.ppo import train_step
if args.compile:
# It seems that using torch.compile twice causes a segfault at startup, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
......@@ -302,7 +302,10 @@ def main():
traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
history.append(traced_model_t)
train_step = torch.compile(train_step, mode=args.compile)
train_step = torch.compile(train_step_, mode=args.compile)
else:
train_step = train_step_
def sample_target(history):
ts = []
......@@ -331,12 +334,12 @@ def main():
warmup_steps = 0
start_time = time.time()
next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
ai_player1_ = np.concatenate([
main_player_ = np.concatenate([
np.zeros(args.local_num_envs // 2, dtype=np.int64),
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1_)
ai_player1 = to_tensor(ai_player1_, device)
np.random.shuffle(main_player_)
main_player = to_tensor(main_player_, device)
next_value1 = next_value2 = 0
step = 0
ts = []
......@@ -374,7 +377,7 @@ def main():
for key in obs:
obs[key][step] = next_obs[key]
dones[step] = next_done
learn = next_to_play == ai_player1
learn = next_to_play == main_player
learns[step] = learn
_start = time.time()
......@@ -410,7 +413,7 @@ def main():
for idx, d in enumerate(next_done_):
if d:
pl = 1 if to_play[idx] == ai_player1_[idx] else -1
pl = 1 if to_play[idx] == main_player_[idx] else -1
episode_length = info['l'][idx]
episode_reward = info['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
......@@ -442,9 +445,9 @@ def main():
value = predict_step(traced_model, next_obs)[1].reshape(-1)
if not selfplay:
value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
value = torch.where(next_to_play == ai_player1, value, value_t)
nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
value = torch.where(next_to_play == main_player, value, value_t)
nextvalues1 = torch.where(next_to_play == main_player, value, next_value1)
nextvalues2 = torch.where(next_to_play != main_player, value, next_value2)
if step > 0 and iteration != 0:
# recalculate the values for the first few steps
......
......@@ -35,18 +35,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
# env_id = infos["env_id"]
# self.env_id = env_id
# self.episode_returns[env_id] += infos["reward"]
# self.returned_episode_returns[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
# )
# self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
# self.episode_lengths[env_id] += 1
# self.returned_episode_lengths[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
# )
# self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
return (
observations,
......
......@@ -121,53 +121,6 @@ def train_step_t(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages, b
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
# logits, newvalue, valid = agent(mb_obs)
# logits = logits - logits.logsumexp(dim=-1, keepdim=True)
# newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
# entropy = entropy_from_logits(logits)
# valid = torch.logical_and(valid, mb_learns)
# logratio = newlogprob - mb_logprobs
# ratio = logratio.exp()
# with torch.no_grad():
# # calculate approx_kl http://joschu.net/blog/kl-approx.html
# old_approx_kl = (-logratio).mean()
# approx_kl = ((ratio - 1) - logratio).mean()
# clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
# if args.norm_adv:
# mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# # Policy loss
# pg_loss1 = -mb_advantages * ratio
# pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
# pg_loss = torch.max(pg_loss1, pg_loss2)
# pg_loss = masked_mean(pg_loss, valid)
# # Value loss
# newvalue = newvalue.view(-1)
# if args.clip_vloss:
# v_loss_unclipped = (newvalue - mb_returns) ** 2
# v_clipped = mb_values + torch.clamp(
# newvalue - mb_values,
# -args.clip_coef,
# args.clip_coef,
# )
# v_loss_clipped = (v_clipped - mb_returns) ** 2
# v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
# v_loss = 0.5 * v_loss_max
# else:
# v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
# v_loss = masked_mean(v_loss, valid)
# entropy_loss = masked_mean(entropy, valid)
# loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
# loss.backward()
# optimizer.step()
# return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
......