Commit 096e743e authored by sbl1996@126.com's avatar sbl1996@126.com

Update agent and UPGO

parent 4f2ad15b
......@@ -3,11 +3,11 @@ import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial
import ygoenv
import flax
......@@ -21,9 +21,11 @@ from rich.pretty import pprint
from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent import PPOAgent
from ygoai.rl.jax.agent2 import PPOAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae, compute_gae_upgo
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
......@@ -31,13 +33,13 @@ os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_paralleli
@dataclass
class Args:
exp_name: str = os.path.basename(__file__).rstrip(".py")
"the name of this experiment"
"""the name of this experiment"""
seed: int = 1
"seed of the experiment"
save_model: bool = False
"whether to save model into the `runs/{run_name}` folder"
log_frequency: int = 2
"the logging frequency of the model performance (in terms of `updates`)"
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval: int = 100
"""the frequency of saving the model"""
# Algorithm specific arguments
env_id: str = "YGOPro-v0"
......@@ -57,40 +59,42 @@ class Args:
n_history_actions: int = 32
"""the number of history actions to use"""
total_timesteps: int = 50000000
"total timesteps of the experiments"
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
learning_rate: float = 1e-3
"the learning rate of the optimizer"
"""the learning rate of the optimizer"""
local_num_envs: int = 128
"the number of parallel game environments"
"""the number of parallel game environments"""
local_env_threads: Optional[int] = None
"the number of threads to use for environment"
"""the number of threads to use for environment"""
num_actor_threads: int = 2
"the number of actor threads to use"
"""the number of actor threads to use"""
num_steps: int = 128
"the number of steps to run in each environment per policy rollout"
anneal_lr: bool = True
"Toggle learning rate annealing for policy and value networks"
"""the number of steps to run in each environment per policy rollout"""
collect_length: Optional[int] = None
"""the number of steps to compute the advantages"""
anneal_lr: bool = False
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 1.0
"the discount factor gamma"
gae_lambda: float = 0.98
"the lambda for the general advantage estimation"
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
upgo: bool = False
"""Toggle the use of UPGO for advantages"""
num_minibatches: int = 8
"the number of mini-batches"
gradient_accumulation_steps: int = 1
"the number of gradient accumulation steps before performing an optimization step"
"""the number of mini-batches"""
update_epochs: int = 2
"the K epochs to update the policy"
"""the K epochs to update the policy"""
norm_adv: bool = False
"Toggles advantages normalization"
clip_coef: float = 0.2
"the surrogate clipping coefficient"
"""Toggles advantages normalization"""
clip_coef: float = 0.25
"""the surrogate clipping coefficient"""
ent_coef: float = 0.01
"coefficient of the entropy"
"""coefficient of the entropy"""
vf_coef: float = 0.5
"coefficient of the value function"
"""coefficient of the value function"""
max_grad_norm: float = 1.0
"the maximum norm for the gradient clipping"
"""the maximum norm for the gradient clipping"""
num_layers: int = 2
"""the number of layers for the agent"""
......@@ -98,21 +102,21 @@ class Args:
"""the number of channels for the agent"""
actor_device_ids: List[int] = field(default_factory=lambda: [0])
"the device ids that actor workers will use"
"""the device ids that actor workers will use"""
learner_device_ids: List[int] = field(default_factory=lambda: [1])
"the device ids that learner workers will use"
"""the device ids that learner workers will use"""
distributed: bool = False
"whether to use `jax.distirbuted`"
"""whether to use `jax.distirbuted`"""
concurrency: bool = True
"whether to run the actor and learner concurrently"
bfloat16: bool = False
"whether to use bfloat16 for the agent"
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"whether to use thread affinity for the environment"
"""whether to use thread affinity for the environment"""
local_eval_episodes: int = 32
"""the number of episodes to evaluate the model"""
eval_interval: int = 10
eval_interval: int = 50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
......@@ -147,6 +151,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
deck2=args.deck2,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
play_mode=mode,
)
envs.num_envs = num_envs
......@@ -158,7 +163,6 @@ class Transition(NamedTuple):
dones: list
actions: list
logprobs: list
values: list
rewards: list
learns: list
......@@ -166,8 +170,7 @@ class Transition(NamedTuple):
def create_agent(args):
return PPOAgent(
channels=args.num_channels,
num_card_layers=args.num_layers,
num_action_layers=args.num_layers,
num_layers=args.num_layers,
embedding_shape=args.num_embeddings,
dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
param_dtype=jnp.float32,
......@@ -205,38 +208,38 @@ def rollout(
global_step = 0
start_time = time.time()
warmup_step = 0
other_time = 0
avg_ep_returns = deque(maxlen=1000)
avg_win_rates = deque(maxlen=1000)
@jax.jit
def apply_fn(
def get_logits(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
next_obs,
):
logits, value, _valid = create_agent(args).apply(params, next_obs)
return logits, value
return create_agent(args).apply(params, next_obs)[0]
def get_action(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
next_obs,
):
return apply_fn(params, next_obs)[0].argmax(axis=1)
return get_logits(params, next_obs).argmax(axis=1)
@jax.jit
def get_action_and_value(
def sample_action(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
next_obs,
key: jax.random.PRNGKey,
):
next_obs = jax.tree_map(lambda x: jnp.array(x), next_obs)
logits, value = apply_fn(params, next_obs)
logits = get_logits(params, next_obs)
# sample action: Gumbel-softmax trick
# see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
return next_obs, action, logprob, value.squeeze(), key
return next_obs, action, logprob, key
# put data in the last index
params_queue_get_time = deque(maxlen=10)
......@@ -250,7 +253,8 @@ def rollout(
np.ones(args.local_num_envs // 2, dtype=np.int64)
])
np.random.shuffle(ai_player1)
next_value1 = next_value2 = 0
start_step = 0
storage = []
@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
......@@ -278,23 +282,18 @@ def rollout(
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
storage = []
for _ in range(0, args.num_steps):
global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids * args.world_size
for _ in range(start_step, args.collect_length):
global_step += args.local_num_envs * n_actors * args.world_size
cached_next_obs = next_obs
cached_next_done = next_done
learn = next_to_play == ai_player1
inference_time_start = time.time()
cached_next_obs, action, logprob, value, key = get_action_and_value(params, cached_next_obs, key)
cached_next_obs, action, logprob, key = sample_action(params, cached_next_obs, key)
cpu_action = np.array(action)
inference_time += time.time() - inference_time_start
next_nonterminal = 1 - next_done.astype(np.float32)
next_value1 = np.where(learn, value, next_value1) * next_nonterminal
next_value2 = np.where(learn, next_value2, value) * next_nonterminal
_start = time.time()
to_play = next_to_play
next_obs, next_reward, next_done, info = envs.step(cpu_action)
......@@ -307,7 +306,6 @@ def rollout(
dones=cached_next_done,
actions=action,
logprobs=logprob,
values=value,
rewards=next_reward,
learns=learn,
)
......@@ -324,7 +322,10 @@ def rollout(
rollout_time.append(time.time() - rollout_time_start)
start_step = args.collect_length - args.num_steps
partitioned_storage = prepare_data(storage)
storage = storage[args.num_steps:]
sharded_storage = []
for x in partitioned_storage:
if isinstance(x, dict):
......@@ -339,7 +340,7 @@ def rollout(
next_learn = ai_player1 == next_to_play
sharded_data = jax.tree_map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices),
(next_obs, next_done, next_value1, next_value2, next_learn))
(next_obs, next_done, next_learn))
payload = (
global_step,
actor_policy_version,
......@@ -353,31 +354,23 @@ def rollout(
if update % args.log_frequency == 0:
avg_episodic_return = np.mean(avg_ep_returns)
SPS = int((global_step - warmup_step) / (time.time() - start_time))
avg_episodic_length = np.mean(envs.returned_episode_lengths)
SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
SPS_update = int(args.batch_size / (time.time() - update_time_start))
if device_thread_id == 0:
print(
f"global_step={global_step}, avg_episodic_return={avg_episodic_return}, rollout_time={np.mean(rollout_time)}"
f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
)
print("SPS:", SPS)
time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", np.mean(envs.returned_episode_lengths), global_step)
writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
writer.add_scalar("stats/inference_time", inference_time, global_step)
writer.add_scalar("stats/env_time", env_time, global_step)
writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar(
"charts/SPS_update",
int(
args.local_num_envs
* args.num_steps
* len_actor_device_ids
* args.num_actor_threads
* args.world_size
/ (time.time() - update_time_start)
),
global_step,
)
writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
......@@ -395,6 +388,7 @@ def rollout(
if device_thread_id == 0:
eval_time = time.time() - _start
print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
other_time += eval_time
if __name__ == "__main__":
......@@ -423,6 +417,8 @@ if __name__ == "__main__":
args.minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.local_env_threads = args.local_env_threads or args.local_num_envs
args.collect_length = args.collect_length or args.num_steps
assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
local_devices = jax.local_devices()
global_devices = jax.devices()
......@@ -460,11 +456,13 @@ if __name__ == "__main__":
args.deck2 = args.deck2 or deck
# env setup
envs = make_env(
args, args.seed, args.local_num_envs, 1)
envs = make_env(args, args.seed, 8, 1)
obs_space = envs.observation_space
action_shape = envs.action_space.shape
print(f"obs_space={obs_space}, action_shape={action_shape}")
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros((args.local_num_envs,) + x.shape[1:])]), obs_space.sample())
envs.close()
del envs
def linear_schedule(count):
# anneal learning rate linearly after one training iteration which contains
......@@ -473,7 +471,6 @@ if __name__ == "__main__":
return args.learning_rate * frac
agent = create_agent(args)
sample_obs = jax.tree_map(lambda x: jnp.array([np.zeros_like(x)]), obs_space.sample())
params = agent.init(agent_key, sample_obs)
tx = optax.MultiSteps(
optax.chain(
......@@ -482,12 +479,12 @@ if __name__ == "__main__":
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
every_k_schedule=args.gradient_accumulation_steps,
every_k_schedule=1,
)
agent_state = TrainState.create(
apply_fn=None,
params=params,
tx=tx,
apply_fn=None,
params=params,
tx=tx,
)
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
......@@ -507,69 +504,9 @@ if __name__ == "__main__":
entropy = -p_log_p.sum(-1)
return logprob, entropy, value.squeeze(), valid
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - curvalues
delta2 = reward2 + gamma * nextvalues2 - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, advantages
compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)
@jax.jit
def compute_gae(
agent_state: TrainState,
next_obs: np.ndarray,
next_done: np.ndarray,
next_value1: np.ndarray,
next_value2: np.ndarray,
next_learn: np.ndarray,
storage: Transition,
):
next_value = create_agent(args).apply(agent_state.params, next_obs)[1].squeeze()
next_value1 = jnp.where(next_learn, next_value, next_value1)
next_value2 = jnp.where(next_learn, next_value2, next_value)
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
_, advantages = jax.lax.scan(
compute_gae_once, carry, (dones[1:], storage.values, storage.rewards, storage.learns), reverse=True
)
return advantages
def ppo_loss(params, obs, actions, behavior_logprobs, advantages, target_values):
def ppo_loss(params, obs, actions, logprobs, advantages, target_values):
newlogprob, entropy, newvalue, valid = get_logprob_entropy_value(params, obs, actions)
logratio = newlogprob - behavior_logprobs
logratio = newlogprob - logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
......@@ -596,34 +533,48 @@ if __name__ == "__main__":
sharded_storages: List,
sharded_next_obs: List,
sharded_next_done: List,
sharded_next_value1: List,
sharded_next_value2: List,
sharded_next_learn: List,
key: jax.random.PRNGKey,
):
def flatten(x):
return x.reshape((-1,) + x.shape[2:])
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jax.tree_map(lambda *x: jnp.concatenate(x), *sharded_next_obs)
next_done, next_value1, next_value2, next_learn = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_value1, sharded_next_value2, sharded_next_learn]
next_done, next_learn = [
jnp.concatenate(x) for x in [sharded_next_done, sharded_next_learn]
]
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
advantages = compute_gae(
agent_state, next_obs, next_done, next_value1, next_value2, next_learn, storage)
target_values = advantages + storage.values
def update_epoch(carry, _):
agent_state, key = carry
key, subkey = jax.random.split(key)
def flatten(x):
return x.reshape((-1,) + x.shape[2:])
def get_value_minibatch(agent_state, mb_obs):
values = create_agent(args).apply(agent_state.params, mb_obs)[1].squeeze(-1)
return agent_state, values
flatten_obs = jax.tree_map(lambda x: x.reshape((-1, args.local_minibatch_size * 8) + x.shape[2:]), storage.obs)
_, values = jax.lax.scan(
get_value_minibatch, agent_state, flatten_obs)
values = values.reshape(storage.rewards.shape)
next_value = create_agent(args).apply(agent_state.params, next_obs)[1].squeeze(-1)
compute_gae_fn = compute_gae_upgo if args.upgo else compute_gae
advantages, target_values = compute_gae_fn(
next_value, next_done, next_learn,
values, storage.rewards, storage.dones, storage.learns,
args.gamma, args.gae_lambda)
advantages = advantages[:args.num_steps]
target_values = target_values[:args.num_steps]
def convert_data(x: jnp.ndarray):
x = jax.random.permutation(subkey, x)
x = jnp.reshape(x, (args.num_minibatches * args.gradient_accumulation_steps, -1) + x.shape[1:])
x = jnp.reshape(x, (-1, args.local_minibatch_size) + x.shape[1:])
return x
flatten_storage = jax.tree_map(flatten, storage)
flatten_storage = jax.tree_map(flatten, jax.tree_map(lambda x: x[:args.num_steps], storage))
flatten_advantages = flatten(advantages)
flatten_target_values = flatten(target_values)
shuffled_storage = jax.tree_map(convert_data, flatten_storage)
......@@ -631,12 +582,12 @@ if __name__ == "__main__":
shuffled_target_values = convert_data(flatten_target_values)
def update_minibatch(agent_state, minibatch):
mb_obs, mb_actions, mb_behavior_logprobs, mb_advantages, mb_target_values = minibatch
mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_target_values = minibatch
(loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
agent_state.params,
mb_obs,
mb_actions,
mb_behavior_logprobs,
mb_logprobs,
mb_advantages,
mb_target_values,
)
......@@ -709,8 +660,6 @@ if __name__ == "__main__":
sharded_storages = []
sharded_next_obss = []
sharded_next_dones = []
sharded_next_values1 = []
sharded_next_values2 = []
sharded_next_learns = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
......@@ -721,8 +670,6 @@ if __name__ == "__main__":
sharded_storage,
sharded_next_obs,
sharded_next_done,
sharded_next_value1,
sharded_next_value2,
sharded_next_learn,
avg_params_queue_get_time,
device_thread_id,
......@@ -730,8 +677,6 @@ if __name__ == "__main__":
sharded_storages.append(sharded_storage)
sharded_next_obss.append(sharded_next_obs)
sharded_next_dones.append(sharded_next_done)
sharded_next_values1.append(sharded_next_value1)
sharded_next_values2.append(sharded_next_value2)
sharded_next_learns.append(sharded_next_learn)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
......@@ -740,8 +685,6 @@ if __name__ == "__main__":
sharded_storages,
sharded_next_obss,
sharded_next_dones,
sharded_next_values1,
sharded_next_values2,
sharded_next_learns,
learner_keys,
)
......@@ -765,7 +708,7 @@ if __name__ == "__main__":
writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
print(
global_step,
f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
......@@ -775,8 +718,26 @@ if __name__ == "__main__":
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
ckpt_dir = f"checkpoints/{run_name}"
os.makedirs(ckpt_dir, exist_ok=True)
model_path = ckpt_dir + "/agent.flax_model"
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(
[
vars(args),
unreplicated_params,
]
)
)
print(f"model saved to {model_path}")
if learner_policy_version >= args.num_updates:
break
envs.close()
if args.distributed:
jax.distributed.shutdown()
writer.close()
\ No newline at end of file
......@@ -18,9 +18,12 @@ class RecordEpisodeStatistics(gym.Wrapper):
return observations, infos
def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action)
return self.update_stats_and_infos(*super().step(action))
def update_stats_and_infos(self, *args):
observations, rewards, terminated, truncated, infos = args
dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards
self.episode_returns += infos.get("reward", rewards)
self.episode_lengths += 1
self.returned_episode_returns = np.where(
dones, self.episode_returns, self.returned_episode_returns
......@@ -32,6 +35,19 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
# env_id = infos["env_id"]
# self.env_id = env_id
# self.episode_returns[env_id] += infos["reward"]
# self.returned_episode_returns[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
# )
# self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
# self.episode_lengths[env_id] += 1
# self.returned_episode_lengths[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
# )
# self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
return (
observations,
rewards,
......@@ -39,6 +55,19 @@ class RecordEpisodeStatistics(gym.Wrapper):
infos,
)
def async_reset(self):
self.env.async_reset()
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def recv(self):
return self.update_stats_and_infos(*self.env.recv())
def send(self, action):
return self.env.send(action)
class CompatEnv(gym.Wrapper):
......
from functools import partial
import jax
import jax.numpy as jnp
from typing import NamedTuple
class VTraceOutput(NamedTuple):
q_estimate: jnp.ndarray
errors: jnp.ndarray
def vtrace(
v_tm1,
v_t,
r_t,
discount_t,
rho_tm1,
lambda_=1.0,
c_clip_min: float = 0.001,
c_clip_max: float = 1.007,
rho_clip_min: float = 0.001,
rho_clip_max: float = 1.007,
stop_target_gradients: bool = True,
):
"""
Args:
v_tm1: values at time t-1.
v_t: values at time t.
r_t: reward at time t.
discount_t: discount at time t.
rho_tm1: importance sampling ratios at time t-1.
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
"""
# Clip importance sampling ratios.
lambda_ = jnp.ones_like(discount_t) * lambda_
c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# Compute the temporal difference errors.
td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# Work backwards computing the td-errors.
def _body(acc, xs):
td_error, discount, c = xs
acc = td_error + discount * c * acc
return acc, acc
_, errors = jax.lax.scan(
_body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# Return errors, maybe disabling gradient flow through bootstrap targets.
errors = jax.lax.select(
stop_target_gradients,
jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
errors)
targets_tm1 = errors + v_tm1
q_bootstrap = jnp.concatenate([
lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
v_t[-1:],
], axis=0)
q_estimate = r_t + discount_t * q_bootstrap
return VTraceOutput(q_estimate=q_estimate, errors=errors)
def upgo_return(r_t, v_t, discount_t, stop_target_gradients: bool = True):
def _body(acc, xs):
r, v, q, discount = xs
acc = r + discount * jnp.where(q >= v, acc, v)
return acc, acc
# TODO: following alphastar, estimate q_t with one-step target
# It might be better to use network to estimate q_t
q_t = r_t[1:] + discount_t[1:] * v_t[1:] # q[:-1]
_, returns = jax.lax.scan(
_body, q_t[-1], (r_t[:-1], v_t[:-1], q_t, discount_t[:-1]), reverse=True)
# Following rlax.vtrace_td_error_and_advantage, part of gradient is reserved
# Experiments show that where to stop gradient has no impact on the performance
returns = jax.lax.select(
stop_target_gradients, jax.lax.stop_gradient(returns), returns)
returns = jnp.concatenate([returns, q_t[-1:]], axis=0)
return returns
def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
clipped_objective = jnp.fmin(prob_ratios_t * adv_t, clipped_ratios_t * adv_t)
return -jnp.mean(clipped_objective * mask)
def compute_gae_once(carry, inp, gamma, gae_lambda):
nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - curvalues
delta2 = reward2 + gamma * nextvalues2 - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, advantages
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
):
next_value1 = jnp.where(next_learn, next_value, -next_value)
next_value2 = -next_value1
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
_, advantages = jax.lax.scan(
partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
carry, (dones[1:], values, rewards, learns), reverse=True
)
target_values = advantages + values
return advantages, target_values
def compute_gae_once_upgo(carry, inp, gamma, gae_lambda):
next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
next_done, curvalues, reward, learn = inp
learn1 = learn
learn2 = ~learn
factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
next_value1 = jnp.where(real_done1, 0, next_value1)
last_return1 = jnp.where(real_done1, 0, last_return1)
lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
next_value2 = jnp.where(real_done2, 0, next_value2)
last_return2 = jnp.where(real_done2, 0, last_return2)
lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
done_used1 = jnp.where(
next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
done_used2 = jnp.where(
next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
last_return1_ = reward1 + gamma * jnp.where(
next_q1 >= next_value1, last_return1, next_value1)
last_return2_ = reward2 + gamma * jnp.where(
next_q2 >= next_value2, last_return2, next_value2)
next_q1_ = reward1 + gamma * next_value1
next_q2_ = reward2 + gamma * next_value2
delta1 = next_q1_ - curvalues
delta2 = next_q2_ - curvalues
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
returns = jnp.where(learn1, last_return1_, last_return2_)
advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
next_value1 = jnp.where(learn1, curvalues, next_value1)
next_value2 = jnp.where(learn2, curvalues, next_value2)
lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
next_q1 = jnp.where(learn1, next_q1_, next_q1)
next_q2 = jnp.where(learn2, next_q2_, next_q1)
last_return1 = jnp.where(learn1, last_return1_, last_return1)
last_return2 = jnp.where(learn2, last_return2_, last_return2)
carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
return carry, (advantages, returns)
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
next_value, next_done, next_learn,
values, rewards, dones, learns,
gamma, gae_lambda,
):
next_value1 = jnp.where(next_learn, next_value, -next_value)
next_value2 = -next_value1
last_return1 = next_q1 = next_value1
last_return2 = next_q2 = next_value2
done_used1 = jnp.ones_like(next_done)
done_used2 = jnp.ones_like(next_done)
reward1 = jnp.zeros_like(next_value)
reward2 = jnp.zeros_like(next_value)
lastgaelam1 = jnp.zeros_like(next_value)
lastgaelam2 = jnp.zeros_like(next_value)
carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
_, (advantages, returns) = jax.lax.scan(
partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda),
carry, (dones[1:], values, rewards, learns), reverse=True
)
return returns - values, advantages + values
......@@ -5,54 +5,13 @@ import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding
def decode_id(x):
x = x[..., 0] * 256 + x[..., 1]
return x
def bytes_to_bin(x, points, intervals):
points = points.astype(x.dtype)
intervals = intervals.astype(x.dtype)
x = decode_id(x)
x = jnp.expand_dims(x, -1)
return jnp.clip((x - points + intervals) / intervals, 0, 1)
def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = jnp.linspace(0, x_max1, sig_bins + 1, dtype=jnp.float32)[1:]
points2 = jnp.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=jnp.float32)[1:]
points = jnp.concatenate([points1, points2], axis=0)
intervals = jnp.concatenate([points[0:1], points[1:] - points[:-1]], axis=0)
return points, intervals
default_embed_init = nn.initializers.uniform(scale=0.0001)
default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.0001)
class MLP(nn.Module):
features: Tuple[int, ...] = (128, 128)
last_lin: bool = True
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
@nn.compact
def __call__(self, x):
n = len(self.features)
for i, c in enumerate(self.features):
x = nn.Dense(
c, dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=self.kernel_init, use_bias=False)(x)
if i < n - 1 or not self.last_lin:
x = nn.relu(x)
return x
default_fc_init2 = nn.initializers.uniform(scale=0.001)
class ActionEncoder(nn.Module):
......@@ -105,18 +64,19 @@ class Encoder(nn.Module):
layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
fc_layer = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = MLP((c // 8,), last_lin=False, dtype=self.dtype, param_dtype=self.param_dtype)
num_fc = MLP((c // 8,), last_lin=False, dtype=jnp.float32, param_dtype=self.param_dtype)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
action_encoder = ActionEncoder(channels=c, dtype=self.dtype, param_dtype=self.param_dtype)
action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
......@@ -125,12 +85,12 @@ class Encoder(nn.Module):
valid = x_global[:, -1] == 0
x_cards_1 = x_cards[:, :, :12].astype(jnp.int32)
x_cards_2 = x_cards[:, :, 12:].astype(self.dtype or jnp.float32)
x_cards_2 = x_cards[:, :, 12:].astype(jnp.float32)
x_id = decode_id(x_cards_1[:, :, :2])
x_id = id_embed(x_id)
x_id = MLP(
(c, c // 4), dtype=self.dtype, param_dtype=self.param_dtype,
(c, c // 4), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
......@@ -152,10 +112,10 @@ class Encoder(nn.Module):
x_negated = embed(3, c // 16)(x_cards_1[:, :, 11])
x_atk = num_transform(x_cards_2[:, :, 0:2])
x_atk = fc_layer(c // 16, kernel_init=default_fc_init1)(x_atk)
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x_cards_2[:, :, 2:4])
x_def = fc_layer(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_layer(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:])
x_feat = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
......@@ -173,14 +133,14 @@ class Encoder(nn.Module):
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1))
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards], axis=1)
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
f_cards = layer_norm()(f_cards)
x_global_1 = x_global[:, :4].astype(self.dtype or jnp.float32)
x_g_lp = fc_layer(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = fc_layer(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_1 = x_global[:, :4].astype(jnp.float32)
x_g_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_2 = x_global[:, 4:8].astype(jnp.int32)
x_g_turn = embed(20, c // 8)(x_global_2[:, 0])
......@@ -197,7 +157,7 @@ class Encoder(nn.Module):
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1)
x_global = layer_norm()(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c)(f_global)
f_global = layer_norm()(f_global)
......@@ -220,14 +180,14 @@ class Encoder(nn.Module):
f_actions, f_cards,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_'].astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP(
(c, c), dtype=self.dtype, param_dtype=self.param_dtype,
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
......@@ -237,9 +197,9 @@ class Encoder(nn.Module):
for _ in range(self.num_action_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_actions = DecoderLayer(num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(
f_actions, f_h_actions,
tgt_key_padding_mask=a_mask,
memory_key_padding_mask=h_mask)
......@@ -261,11 +221,12 @@ class Actor(nn.Module):
@nn.compact
def __call__(self, f_actions, mask):
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
num_heads = max(2, c // 128)
f_actions = EncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask)
logits = MLP((c // 4, 1), dtype=self.dtype, param_dtype=self.param_dtype)(f_actions)
logits = logits[..., 0].astype(jnp.float32)
num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask)
logits = mlp((c // 4, 1), use_bias=True)(f_actions)
logits = logits[..., 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
......@@ -279,8 +240,8 @@ class Critic(nn.Module):
@nn.compact
def __call__(self, f_state):
c = self.channels
x = MLP((c // 2, 1), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
x = x.astype(jnp.float32)
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(1.0))
x = MLP((c // 2, 1), use_bias=True)(f_state)
return x
......
from typing import Tuple, Union, Optional, Sequence
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn
from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
default_embed_init = nn.initializers.uniform(scale=0.001)
default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.001)
class ActionEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
div = 8
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype,
embedding_init=default_embed_init)
x_a_msg = embed(30, c // div)(x[:, :, 0])
x_a_act = embed(13, c // div)(x[:, :, 1])
x_a_yesno = embed(3, c // div)(x[:, :, 2])
x_a_phase = embed(4, c // div)(x[:, :, 3])
x_a_cancel = embed(3, c // div)(x[:, :, 4])
x_a_finish = embed(3, c // div // 2)(x[:, :, 5])
x_a_position = embed(9, c // div // 2)(x[:, :, 6])
x_a_option = embed(6, c // div // 2)(x[:, :, 7])
x_a_number = embed(13, c // div // 2)(x[:, :, 8])
x_a_place = embed(31, c // div // 2)(x[:, :, 9])
x_a_attrib = embed(10, c // div // 2)(x[:, :, 10])
return jnp.concatenate([
x_a_msg, x_a_act, x_a_yesno, x_a_phase, x_a_cancel, x_a_finish,
x_a_position, x_a_option, x_a_number, x_a_place, x_a_attrib], axis=-1)
class CardEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x_id, x):
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :, :10].astype(jnp.int32)
x2 = x[:, :, 10:].astype(jnp.float32)
x_id = mlp(
(c, c // 4), kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id)
x_loc = x1[:, :, 0]
f_loc = layer_norm()(embed(9, c)(x_loc))
x_seq = x1[:, :, 1]
f_seq = layer_norm()(embed(76, c)(x_seq))
x_owner = embed(2, c // 16)(x1[:, :, 2])
x_position = embed(9, c // 16)(x1[:, :, 3])
x_overley = embed(2, c // 16)(x1[:, :, 4])
x_attribute = embed(8, c // 16)(x1[:, :, 5])
x_race = embed(27, c // 16)(x1[:, :, 6])
x_level = embed(14, c // 16)(x1[:, :, 7])
x_counter = embed(16, c // 16)(x1[:, :, 8])
x_negated = embed(3, c // 16)(x1[:, :, 9])
x_atk = num_transform(x2[:, :, 0:2])
x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x2[:, :, 2:4])
x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x2[:, :, 4:])
x_f = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute,
x_race, x_level, x_counter, x_negated,
x_atk, x_def, x_type], axis=-1)
x_f = layer_norm()(x_f)
f_cards = jnp.concatenate([x_id, x_f], axis=-1)
f_cards = f_cards + f_loc + f_seq
return f_cards
class GlobalEncoder(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
batch_size = x.shape[0]
c = self.channels
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16)
num_fc = mlp((c // 8,), last_lin=False)
bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
x1 = x[:, :4].astype(jnp.float32)
x2 = x[:, 4:8].astype(jnp.int32)
x3 = x[:, 8:22].astype(jnp.int32)
x_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 0:2]))
x_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x1[:, 2:4]))
x_turn = embed(20, c // 8)(x2[:, 0])
x_phase = embed(11, c // 8)(x2[:, 1])
x_if_first = embed(2, c // 8)(x2[:, 2])
x_is_my_turn = embed(2, c // 8)(x2[:, 3])
x_cs = count_embed(x3).reshape((batch_size, -1))
x_my_hand_c = hand_count_embed(x3[:, 1])
x_op_hand_c = hand_count_embed(x3[:, 8])
x = jnp.concatenate([
x_lp, x_oppo_lp, x_turn, x_phase, x_if_first, x_is_my_turn,
x_cs, x_my_hand_c, x_op_hand_c], axis=-1)
x = layer_norm()(x)
return x
class Encoder(nn.Module):
channels: int = 128
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
if self.embedding_shape is None:
n_embed, embed_dim = 999, 1024
elif isinstance(self.embedding_shape, int):
n_embed, embed_dim = self.embedding_shape, 1024
else:
n_embed, embed_dim = self.embedding_shape
n_embed = 1 + n_embed # 1 (index 0) for unknown
layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
embed = partial(
nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim)
action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_']
x_global = x['global_']
x_actions = x['actions_']
x_h_actions = x['h_actions_']
batch_size = x_cards.shape[0]
valid = x_global[:, -1] == 0
x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
x_id = id_embed(x_id)
# Cards
f_cards = CardEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
g_card_embed = self.param(
'g_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
num_heads = max(2, c // 128)
for _ in range(self.num_layers):
f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
f_cards = layer_norm(dtype=self.dtype)(f_cards)
f_g_card = f_cards[:, 0]
# Global
x_global = GlobalEncoder(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
x_global = x_global.astype(self.dtype)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c, dtype=self.dtype)(f_global)
f_global = layer_norm(dtype=self.dtype)(f_global)
# History actions
x_h_actions = x_h_actions.astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP(
(c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
f_h_actions = PositionalEncoding()(f_h_actions)
for _ in range(self.num_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask)
f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
# Actions
x_actions = x_actions.astype(jnp.int32)
na_card_embed = self.param(
'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards[:, 1:]], axis=1)
spec_index = decode_id(x_actions[..., :2])
B = jnp.arange(batch_size)
f_a_cards = f_cards[B[:, None], spec_index]
f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
x_a_feats = action_encoder(x_actions[..., 2:])
x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
f_actions = layer_norm(dtype=self.dtype)(f_actions)
a_mask = x_actions[:, :, 2] == 0
a_mask = a_mask.at[:, 0].set(False)
a_mask_ = (1 - a_mask.astype(f_actions.dtype))
f_g_actions = (f_actions * a_mask_[:, :, None]).sum(axis=1)
f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
# State
f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
f_state = layer_norm(dtype=self.dtype)(f_state)
# TODO: LSTM
return f_actions, f_state, a_mask, valid
class Actor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
f_state = mlp((c,), use_bias=True)(f_state)
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
class Critic(nn.Module):
channels: Sequence[int] = (128, 128, 128)
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, f_state):
f_state = f_state.astype(self.dtype)
mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
x = mlp(self.channels, last_lin=False)(f_state)
x = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))(x)
return x
class PPOAgent(nn.Module):
channels: int = 128
num_layers: int = 2
embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, x):
c = self.channels
encoder = Encoder(
channels=c,
num_layers=self.num_layers,
embedding_shape=self.embedding_shape,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
critic = Critic(
channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
f_actions, f_state, mask, valid = encoder(x)
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return logits, value, valid
from typing import Tuple, Union, Optional
import jax.numpy as jnp
import flax.linen as nn
def decode_id(x):
x = x[..., 0] * 256 + x[..., 1]
return x
def bytes_to_bin(x, points, intervals):
points = points.astype(x.dtype)
intervals = intervals.astype(x.dtype)
x = decode_id(x)
x = jnp.expand_dims(x, -1)
return jnp.clip((x - points + intervals) / intervals, 0, 1)
def make_bin_params(x_max=12000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = jnp.linspace(0, x_max1, sig_bins + 1, dtype=jnp.float32)[1:]
points2 = jnp.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=jnp.float32)[1:]
points = jnp.concatenate([points1, points2], axis=0)
intervals = jnp.concatenate([points[0:1], points[1:] - points[:-1]], axis=0)
return points, intervals
class MLP(nn.Module):
features: Tuple[int, ...] = (128, 128)
last_lin: bool = True
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
use_bias: bool = False
@nn.compact
def __call__(self, x):
n = len(self.features)
for i, c in enumerate(self.features):
if self.last_lin and i == n - 1:
kernel_init = self.last_kernel_init
else:
kernel_init = self.kernel_init
x = nn.Dense(
c, dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=kernel_init, use_bias=self.use_bias)(x)
if i < n - 1 or not self.last_lin:
x = nn.leaky_relu(x, negative_slope=0.1)
return x
......@@ -632,6 +632,7 @@ class EncoderLayer(nn.Module):
@nn.compact
def __call__(self, inputs, src_key_padding_mask=None):
inputs = jnp.asarray(inputs, self.dtype)
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention(
......
import numpy as np
import torch
from torch.distributions import Categorical
from torch.cuda.amp import autocast
import torch_xla.core.xla_model as xm
from ygoai.rl.utils import masked_normalize, masked_mean
def entropy_from_logits(logits):
min_real = torch.finfo(logits.dtype).min
logits = torch.clamp(logits, min=min_real)
p_log_p = logits * torch.softmax(logits, dim=-1)
return -p_log_p.sum(-1)
def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
with autocast(enabled=args.fp16_train):
logits, newvalue, valid = agent(mb_obs)
probs = Categorical(logits=logits)
newlogprob = probs.log_prob(mb_actions)
entropy = probs.entropy()
logits = logits - logits.logsumexp(dim=-1, keepdim=True)
newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
entropy = entropy_from_logits(logits)
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
......@@ -57,6 +66,108 @@ def train_step(agent, optimizer, scaler, mb_obs, mb_actions, mb_logprobs, mb_adv
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
def train_step_t(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages, b_returns, b_values, b_learns, mb_inds, args):
mb_obs = {
k: v[mb_inds] for k, v in b_obs.items()
}
mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns = [
v[mb_inds] for v in [b_actions, b_logprobs, b_advantages, b_returns, b_values, b_learns]]
optimizer.zero_grad(True)
logits, newvalue, valid = agent(mb_obs)
logits = logits - logits.logsumexp(dim=-1, keepdim=True)
newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
entropy = entropy_from_logits(logits)
valid = torch.logical_and(valid, mb_learns)
logratio = newlogprob - mb_logprobs
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
if args.norm_adv:
mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = masked_mean(pg_loss, valid)
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_clipped = mb_values + torch.clamp(
newvalue - mb_values,
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max
else:
v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
v_loss = masked_mean(v_loss, valid)
entropy_loss = masked_mean(entropy, valid)
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
loss.backward()
xm.optimizer_step(optimizer)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
# logits, newvalue, valid = agent(mb_obs)
# logits = logits - logits.logsumexp(dim=-1, keepdim=True)
# newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
# entropy = entropy_from_logits(logits)
# valid = torch.logical_and(valid, mb_learns)
# logratio = newlogprob - mb_logprobs
# ratio = logratio.exp()
# with torch.no_grad():
# # calculate approx_kl http://joschu.net/blog/kl-approx.html
# old_approx_kl = (-logratio).mean()
# approx_kl = ((ratio - 1) - logratio).mean()
# clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
# if args.norm_adv:
# mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
# # Policy loss
# pg_loss1 = -mb_advantages * ratio
# pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
# pg_loss = torch.max(pg_loss1, pg_loss2)
# pg_loss = masked_mean(pg_loss, valid)
# # Value loss
# newvalue = newvalue.view(-1)
# if args.clip_vloss:
# v_loss_unclipped = (newvalue - mb_returns) ** 2
# v_clipped = mb_values + torch.clamp(
# newvalue - mb_values,
# -args.clip_coef,
# args.clip_coef,
# )
# v_loss_clipped = (v_clipped - mb_returns) ** 2
# v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
# v_loss = 0.5 * v_loss_max
# else:
# v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
# v_loss = masked_mean(v_loss, valid)
# entropy_loss = masked_mean(entropy, valid)
# loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
# loss.backward()
# optimizer.step()
# return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
......@@ -206,4 +317,115 @@ def bootstrap_value_selfplay(values, rewards, dones, learns, nextvalues1, nextva
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
def bootstrap_value_selfplay_upgo(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
num_steps = rewards.size(0)
advantages = torch.zeros_like(rewards)
# TODO: optimize this
done_used1 = torch.ones_like(next_done, dtype=torch.bool)
done_used2 = torch.ones_like(next_done, dtype=torch.bool)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(num_steps)):
# if learns[t]:
# if dones[t+1]:
# reward1 = rewards[t]
# next_values1 = 0
# last_return1 = 0
# lastgaelam1 = 0
# done_used1 = True
#
# reward2 = -rewards[t]
# done_used2 = False
# else:
# if not done_used1:
# reward1 = reward1
# next_values1 = 0
# last_return1 = 0
# lastgaelam1 = 0
# done_used1 = True
# else:
# reward1 = rewards[t]
# reward2 = reward2
# last_return1_ = reward1 + args.gamma * (last_return1 if (next_qs1 >= next_values1) else next_values1)
# next_q1_ = reward1 + args.gamma * next_values1
# delta1 = next_q1_ - values[t]
# lastgaelam1_ = delta1 + args.gamma * args.gae_lambda * lastgaelam1
# returns[t] = last_return1_
# advantages[t] = lastgaelam1_
# next_values1 = values[t]
# lastgaelam1 = lastgaelam1_
# next_qs1 = next_q1_
# last_return1 = last_return1_
# else:
# Skip because it is symmetric
learn1 = learns[t]
learn2 = ~learn1
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.int() - 0.5)
reward1 = torch.where(next_done, rewards[t] * sp, torch.where(learn1 & done_used1, 0, reward1))
reward2 = torch.where(next_done, rewards[t] * -sp, torch.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = torch.where(real_done1, 0, nextvalues1)
lastgaelam1 = torch.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = torch.where(real_done2, 0, nextvalues2)
lastgaelam2 = torch.where(real_done2, 0, lastgaelam2)
done_used1 = torch.where(
next_done, learn1, torch.where(learn1 & ~done_used1, True, done_used1))
done_used2 = torch.where(
next_done, learn2, torch.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - values[t]
delta2 = reward2 + gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages[t] = torch.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = torch.where(learn1, values[t], nextvalues1)
nextvalues2 = torch.where(learn2, values[t], nextvalues2)
lastgaelam1 = torch.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = torch.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
def bootstrap_value_selfplay_np(values, rewards, dones, learns, nextvalues1, nextvalues2, next_done, gamma, gae_lambda):
num_steps = rewards.shape[0]
advantages = np.zeros_like(rewards)
# TODO: optimize this
done_used1 = np.ones_like(next_done, dtype=np.bool_)
done_used2 = np.ones_like(next_done, dtype=np.bool_)
reward1 = reward2 = 0
lastgaelam1 = lastgaelam2 = 0
for t in reversed(range(num_steps)):
learn1 = learns[t]
learn2 = ~learn1
if t != num_steps - 1:
next_done = dones[t + 1]
sp = 2 * (learn1.astype(np.float32) - 0.5)
reward1 = np.where(next_done, rewards[t] * sp, np.where(learn1 & done_used1, 0, reward1))
reward2 = np.where(next_done, rewards[t] * -sp, np.where(learn2 & done_used2, 0, reward2))
real_done1 = next_done | ~done_used1
nextvalues1 = np.where(real_done1, 0, nextvalues1)
lastgaelam1 = np.where(real_done1, 0, lastgaelam1)
real_done2 = next_done | ~done_used2
nextvalues2 = np.where(real_done2, 0, nextvalues2)
lastgaelam2 = np.where(real_done2, 0, lastgaelam2)
done_used1 = np.where(
next_done, learn1, np.where(learn1 & ~done_used1, True, done_used1))
done_used2 = np.where(
next_done, learn2, np.where(learn2 & ~done_used2, True, done_used2))
delta1 = reward1 + gamma * nextvalues1 - values[t]
delta2 = reward2 + gamma * nextvalues2 - values[t]
lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
advantages[t] = np.where(learn1, lastgaelam1_, lastgaelam2_)
nextvalues1 = np.where(learn1, values[t], nextvalues1)
nextvalues2 = np.where(learn2, values[t], nextvalues2)
lastgaelam1 = np.where(learn1, lastgaelam1_, lastgaelam1)
lastgaelam2 = np.where(learn2, lastgaelam2_, lastgaelam2)
return advantages
\ No newline at end of file
......@@ -55,7 +55,7 @@ def masked_normalize(x, valid, eps=1e-8):
return (x - mean) / std
def to_tensor(x, device, dtype=torch.float32):
def to_tensor(x, device, dtype=None):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
......
import envpool2
print(envpool2.list_all_envs())
\ No newline at end of file
#ifndef BS_THREAD_POOL_HPP
#define BS_THREAD_POOL_HPP
/**
* @file BS_thread_pool.hpp
* @author Barak Shoshany (baraksh@gmail.com) (https://baraksh.com)
* @version 4.1.0
* @date 2024-03-22
* @copyright Copyright (c) 2024 Barak Shoshany. Licensed under the MIT license. If you found this project useful, please consider starring it on GitHub! If you use this library in software of any kind, please provide a link to the GitHub repository https://github.com/bshoshany/thread-pool in the source code and documentation. If you use this library in published research, please cite it as follows: Barak Shoshany, "A C++17 Thread Pool for High-Performance Scientific Computing", doi:10.1016/j.softx.2024.101687, SoftwareX 26 (2024) 101687, arXiv:2105.00613
*
* @brief BS::thread_pool: a fast, lightweight, and easy-to-use C++17 thread pool library. This header file contains the main thread pool class and some additional classes and definitions. No other files are needed in order to use the thread pool itself.
*/
#ifndef __cpp_exceptions
#define BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
#undef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
#endif
#include <chrono> // std::chrono
#include <condition_variable> // std::condition_variable
#include <cstddef> // std::size_t
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
#include <cstdint> // std::int_least16_t
#endif
#ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
#include <exception> // std::current_exception
#endif
#include <functional> // std::function
#include <future> // std::future, std::future_status, std::promise
#include <memory> // std::make_shared, std::make_unique, std::shared_ptr, std::unique_ptr
#include <mutex> // std::mutex, std::scoped_lock, std::unique_lock
#include <optional> // std::nullopt, std::optional
#include <queue> // std::priority_queue (if priority enabled), std::queue
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
#include <stdexcept> // std::runtime_error
#endif
#include <thread> // std::thread
#include <type_traits> // std::conditional_t, std::decay_t, std::invoke_result_t, std::is_void_v, std::remove_const_t (if priority enabled)
#include <utility> // std::forward, std::move
#include <vector> // std::vector
/**
* @brief A namespace used by Barak Shoshany's projects.
*/
namespace BS {
// Macros indicating the version of the thread pool library.
#define BS_THREAD_POOL_VERSION_MAJOR 4
#define BS_THREAD_POOL_VERSION_MINOR 1
#define BS_THREAD_POOL_VERSION_PATCH 0
class thread_pool;
/**
* @brief A type to represent the size of things.
*/
using size_t = std::size_t;
/**
* @brief A convenient shorthand for the type of `std::thread::hardware_concurrency()`. Should evaluate to unsigned int.
*/
using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>;
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
/**
* @brief A type used to indicate the priority of a task. Defined to be an integer with a width of (at least) 16 bits.
*/
using priority_t = std::int_least16_t;
/**
* @brief A namespace containing some pre-defined priorities for convenience.
*/
namespace pr {
constexpr priority_t highest = 32767;
constexpr priority_t high = 16383;
constexpr priority_t normal = 0;
constexpr priority_t low = -16384;
constexpr priority_t lowest = -32768;
} // namespace pr
// Macros used internally to enable or disable the priority arguments in the relevant functions.
#define BS_THREAD_POOL_PRIORITY_INPUT , const priority_t priority = 0
#define BS_THREAD_POOL_PRIORITY_OUTPUT , priority
#else
#define BS_THREAD_POOL_PRIORITY_INPUT
#define BS_THREAD_POOL_PRIORITY_OUTPUT
#endif
/**
* @brief A namespace used to obtain information about the current thread.
*/
namespace this_thread {
/**
* @brief A type returned by `BS::this_thread::get_index()` which can optionally contain the index of a thread, if that thread belongs to a `BS::thread_pool`. Otherwise, it will contain no value.
*/
using optional_index = std::optional<size_t>;
/**
* @brief A type returned by `BS::this_thread::get_pool()` which can optionally contain the pointer to the pool that owns a thread, if that thread belongs to a `BS::thread_pool`. Otherwise, it will contain no value.
*/
using optional_pool = std::optional<thread_pool*>;
/**
* @brief A helper class to store information about the index of the current thread.
*/
class [[nodiscard]] thread_info_index
{
friend class BS::thread_pool;
public:
/**
* @brief Get the index of the current thread. If this thread belongs to a `BS::thread_pool` object, it will have an index from 0 to `BS::thread_pool::get_thread_count() - 1`. Otherwise, for example if this thread is the main thread or an independent `std::thread`, `std::nullopt` will be returned.
*
* @return An `std::optional` object, optionally containing a thread index. Unless you are 100% sure this thread is in a pool, first use `std::optional::has_value()` to check if it contains a value, and if so, use `std::optional::value()` to obtain that value.
*/
[[nodiscard]] optional_index operator()() const
{
return index;
}
private:
/**
* @brief The index of the current thread.
*/
optional_index index = std::nullopt;
}; // class thread_info_index
/**
* @brief A helper class to store information about the thread pool that owns the current thread.
*/
class [[nodiscard]] thread_info_pool
{
friend class BS::thread_pool;
public:
/**
* @brief Get the pointer to the thread pool that owns the current thread. If this thread belongs to a `BS::thread_pool` object, a pointer to that object will be returned. Otherwise, for example if this thread is the main thread or an independent `std::thread`, `std::nullopt` will be returned.
*
* @return An `std::optional` object, optionally containing a pointer to a thread pool. Unless you are 100% sure this thread is in a pool, first use `std::optional::has_value()` to check if it contains a value, and if so, use `std::optional::value()` to obtain that value.
*/
[[nodiscard]] optional_pool operator()() const
{
return pool;
}
private:
/**
* @brief A pointer to the thread pool that owns the current thread.
*/
optional_pool pool = std::nullopt;
}; // class thread_info_pool
/**
* @brief A `thread_local` object used to obtain information about the index of the current thread.
*/
inline thread_local thread_info_index get_index;
/**
* @brief A `thread_local` object used to obtain information about the thread pool that owns the current thread.
*/
inline thread_local thread_info_pool get_pool;
} // namespace this_thread
/**
* @brief A helper class to facilitate waiting for and/or getting the results of multiple futures at once.
*
* @tparam T The return type of the futures.
*/
template <typename T>
class [[nodiscard]] multi_future : public std::vector<std::future<T>>
{
public:
// Inherit all constructors from the base class `std::vector`.
using std::vector<std::future<T>>::vector;
// The copy constructor and copy assignment operator are deleted. The elements stored in a `multi_future` are futures, which cannot be copied.
multi_future(const multi_future&) = delete;
multi_future& operator=(const multi_future&) = delete;
// The move constructor and move assignment operator are defaulted.
multi_future(multi_future&&) = default;
multi_future& operator=(multi_future&&) = default;
/**
* @brief Get the results from all the futures stored in this `multi_future`, rethrowing any stored exceptions.
*
* @return If the futures return `void`, this function returns `void` as well. Otherwise, it returns a vector containing the results.
*/
[[nodiscard]] std::conditional_t<std::is_void_v<T>, void, std::vector<T>> get()
{
if constexpr (std::is_void_v<T>)
{
for (std::future<T>& future : *this)
future.get();
return;
}
else
{
std::vector<T> results;
results.reserve(this->size());
for (std::future<T>& future : *this)
results.push_back(future.get());
return results;
}
}
/**
* @brief Check how many of the futures stored in this `multi_future` are ready.
*
* @return The number of ready futures.
*/
[[nodiscard]] size_t ready_count() const
{
size_t count = 0;
for (const std::future<T>& future : *this)
{
if (future.wait_for(std::chrono::duration<double>::zero()) == std::future_status::ready)
++count;
}
return count;
}
/**
* @brief Check if all the futures stored in this `multi_future` are valid.
*
* @return `true` if all futures are valid, `false` if at least one of the futures is not valid.
*/
[[nodiscard]] bool valid() const
{
bool is_valid = true;
for (const std::future<T>& future : *this)
is_valid = is_valid && future.valid();
return is_valid;
}
/**
* @brief Wait for all the futures stored in this `multi_future`.
*/
void wait() const
{
for (const std::future<T>& future : *this)
future.wait();
}
/**
* @brief Wait for all the futures stored in this `multi_future`, but stop waiting after the specified duration has passed. This function first waits for the first future for the desired duration. If that future is ready before the duration expires, this function waits for the second future for whatever remains of the duration. It continues similarly until the duration expires.
*
* @tparam R An arithmetic type representing the number of ticks to wait.
* @tparam P An `std::ratio` representing the length of each tick in seconds.
* @param duration The amount of time to wait.
* @return `true` if all futures have been waited for before the duration expired, `false` otherwise.
*/
template <typename R, typename P>
bool wait_for(const std::chrono::duration<R, P>& duration) const
{
const std::chrono::time_point<std::chrono::steady_clock> start_time = std::chrono::steady_clock::now();
for (const std::future<T>& future : *this)
{
future.wait_for(duration - (std::chrono::steady_clock::now() - start_time));
if (duration < std::chrono::steady_clock::now() - start_time)
return false;
}
return true;
}
/**
* @brief Wait for all the futures stored in this `multi_future`, but stop waiting after the specified time point has been reached. This function first waits for the first future until the desired time point. If that future is ready before the time point is reached, this function waits for the second future until the desired time point. It continues similarly until the time point is reached.
*
* @tparam C The type of the clock used to measure time.
* @tparam D An `std::chrono::duration` type used to indicate the time point.
* @param timeout_time The time point at which to stop waiting.
* @return `true` if all futures have been waited for before the time point was reached, `false` otherwise.
*/
template <typename C, typename D>
bool wait_until(const std::chrono::time_point<C, D>& timeout_time) const
{
for (const std::future<T>& future : *this)
{
future.wait_until(timeout_time);
if (timeout_time < std::chrono::steady_clock::now())
return false;
}
return true;
}
}; // class multi_future
/**
* @brief A fast, lightweight, and easy-to-use C++17 thread pool class.
*/
class [[nodiscard]] thread_pool
{
public:
// ============================
// Constructors and destructors
// ============================
/**
* @brief Construct a new thread pool. The number of threads will be the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads.
*/
thread_pool() : thread_pool(0, [] {}) {}
/**
* @brief Construct a new thread pool with the specified number of threads.
*
* @param num_threads The number of threads to use.
*/
explicit thread_pool(const concurrency_t num_threads) : thread_pool(num_threads, [] {}) {}
/**
* @brief Construct a new thread pool with the specified initialization function.
*
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
explicit thread_pool(const std::function<void()>& init_task) : thread_pool(0, init_task) {}
/**
* @brief Construct a new thread pool with the specified number of threads and initialization function.
*
* @param num_threads The number of threads to use.
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
thread_pool(const concurrency_t num_threads, const std::function<void()>& init_task) : thread_count(determine_thread_count(num_threads)), threads(std::make_unique<std::thread[]>(determine_thread_count(num_threads)))
{
create_threads(init_task);
}
// The copy and move constructors and assignment operators are deleted. The thread pool uses a mutex, which cannot be copied or moved.
thread_pool(const thread_pool&) = delete;
thread_pool(thread_pool&&) = delete;
thread_pool& operator=(const thread_pool&) = delete;
thread_pool& operator=(thread_pool&&) = delete;
/**
* @brief Destruct the thread pool. Waits for all tasks to complete, then destroys all threads. Note that if the pool is paused, then any tasks still in the queue will never be executed.
*/
~thread_pool()
{
wait();
destroy_threads();
}
// =======================
// Public member functions
// =======================
#ifdef BS_THREAD_POOL_ENABLE_NATIVE_HANDLES
/**
* @brief Get a vector containing the underlying implementation-defined thread handles for each of the pool's threads, as obtained by `std::thread::native_handle()`. Only enabled if `BS_THREAD_POOL_ENABLE_NATIVE_HANDLES` is defined.
*
* @return The native thread handles.
*/
[[nodiscard]] std::vector<std::thread::native_handle_type> get_native_handles() const
{
std::vector<std::thread::native_handle_type> native_handles(thread_count);
for (concurrency_t i = 0; i < thread_count; ++i)
{
native_handles[i] = threads[i].native_handle();
}
return native_handles;
}
#endif
/**
* @brief Get the number of tasks currently waiting in the queue to be executed by the threads.
*
* @return The number of queued tasks.
*/
[[nodiscard]] size_t get_tasks_queued() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return tasks.size();
}
/**
* @brief Get the number of tasks currently being executed by the threads.
*
* @return The number of running tasks.
*/
[[nodiscard]] size_t get_tasks_running() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return tasks_running;
}
/**
* @brief Get the total number of unfinished tasks: either still waiting in the queue, or running in a thread. Note that `get_tasks_total() == get_tasks_queued() + get_tasks_running()`.
*
* @return The total number of tasks.
*/
[[nodiscard]] size_t get_tasks_total() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return tasks_running + tasks.size();
}
/**
* @brief Get the number of threads in the pool.
*
* @return The number of threads.
*/
[[nodiscard]] concurrency_t get_thread_count() const
{
return thread_count;
}
/**
* @brief Get a vector containing the unique identifiers for each of the pool's threads, as obtained by `std::thread::get_id()`.
*
* @return The unique thread identifiers.
*/
[[nodiscard]] std::vector<std::thread::id> get_thread_ids() const
{
std::vector<std::thread::id> thread_ids(thread_count);
for (concurrency_t i = 0; i < thread_count; ++i)
{
thread_ids[i] = threads[i].get_id();
}
return thread_ids;
}
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
/**
* @brief Check whether the pool is currently paused. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined.
*
* @return `true` if the pool is paused, `false` if it is not paused.
*/
[[nodiscard]] bool is_paused() const
{
const std::scoped_lock tasks_lock(tasks_mutex);
return paused;
}
/**
* @brief Pause the pool. The workers will temporarily stop retrieving new tasks out of the queue, although any tasks already executed will keep running until they are finished. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined.
*/
void pause()
{
const std::scoped_lock tasks_lock(tasks_mutex);
paused = true;
}
#endif
/**
* @brief Purge all the tasks waiting in the queue. Tasks that are currently running will not be affected, but any tasks still waiting in the queue will be discarded, and will never be executed by the threads. Please note that there is no way to restore the purged tasks.
*/
void purge()
{
const std::scoped_lock tasks_lock(tasks_mutex);
while (!tasks.empty())
tasks.pop();
}
/**
* @brief Submit a function with no arguments and no return value into the task queue, with the specified priority. To push a function with arguments, enclose it in a lambda expression. Does not return a future, so the user must use `wait()` or some other method to ensure that the task finishes executing, otherwise bad things will happen.
*
* @tparam F The type of the function.
* @param task The function to push.
* @param priority The priority of the task. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename F>
void detach_task(F&& task BS_THREAD_POOL_PRIORITY_INPUT)
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
tasks.emplace(std::forward<F>(task) BS_THREAD_POOL_PRIORITY_OUTPUT);
}
task_available_cv.notify_one();
}
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The block function takes two arguments, the start and end of the block, so that it is only called only once per block, but it is up to the user make sure the block function correctly deals with all the indices in each block. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the loop finishes executing, otherwise bad things will happen.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted.
* @param block A function that will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. `block(start, end)` should typically involve a loop of the form `for (T i = start; i < end; ++i)`.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename T, typename F>
void detach_blocks(const T first_index, const T index_after_last, F&& block, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
detach_task(
[block = std::forward<F>(block), start = blks.start(blk), end = blks.end(blk)]
{
block(start, end);
} BS_THREAD_POOL_PRIORITY_OUTPUT);
}
}
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The loop function takes one argument, the loop index, so that it is called many times per block. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the loop finishes executing, otherwise bad things will happen.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted.
* @param loop The function to loop through. Will be called once per index, many times per block. Should take exactly one argument: the loop index.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename T, typename F>
void detach_loop(const T first_index, const T index_after_last, F&& loop, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
detach_task(
[loop = std::forward<F>(loop), start = blks.start(blk), end = blks.end(blk)]
{
for (T i = start; i < end; ++i)
loop(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT);
}
}
/**
* @brief Submit a sequence of tasks enumerated by indices to the queue, with the specified priority. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the sequence finishes executing, otherwise bad things will happen.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function used to define the sequence.
* @param first_index The first index in the sequence.
* @param index_after_last The index after the last index in the sequence. The sequence will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted.
* @param sequence The function used to define the sequence. Will be called once per index. Should take exactly one argument, the index.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
*/
template <typename T, typename F>
void detach_sequence(const T first_index, const T index_after_last, F&& sequence BS_THREAD_POOL_PRIORITY_INPUT)
{
for (T i = first_index; i < index_after_last; ++i)
detach_task(
[sequence = std::forward<F>(sequence), i]
{
sequence(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT);
}
/**
* @brief Reset the pool with the total number of hardware threads available, as reported by the implementation. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*/
void reset()
{
reset(0, [] {});
}
/**
* @brief Reset the pool with a new number of threads. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param num_threads The number of threads to use.
*/
void reset(const concurrency_t num_threads)
{
reset(num_threads, [] {});
}
/**
* @brief Reset the pool with the total number of hardware threads available, as reported by the implementation, and a new initialization function. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads and initialization function. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
void reset(const std::function<void()>& init_task)
{
reset(0, init_task);
}
/**
* @brief Reset the pool with a new number of threads and a new initialization function. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads and initialization function. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param num_threads The number of threads to use.
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed.
*/
void reset(const concurrency_t num_threads, const std::function<void()>& init_task)
{
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
std::unique_lock tasks_lock(tasks_mutex);
const bool was_paused = paused;
paused = true;
tasks_lock.unlock();
#endif
wait();
destroy_threads();
thread_count = determine_thread_count(num_threads);
threads = std::make_unique<std::thread[]>(thread_count);
create_threads(init_task);
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
tasks_lock.lock();
paused = was_paused;
#endif
}
/**
* @brief Submit a function with no arguments into the task queue, with the specified priority. To submit a function with arguments, enclose it in a lambda expression. If the function has a return value, get a future for the eventual returned value. If the function has no return value, get an `std::future<void>` which can be used to wait until the task finishes.
*
* @tparam F The type of the function.
* @tparam R The return type of the function (can be `void`).
* @param task The function to submit.
* @param priority The priority of the task. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A future to be used later to wait for the function to finish executing and/or obtain its returned value if it has one.
*/
template <typename F, typename R = std::invoke_result_t<std::decay_t<F>>>
[[nodiscard]] std::future<R> submit_task(F&& task BS_THREAD_POOL_PRIORITY_INPUT)
{
const std::shared_ptr<std::promise<R>> task_promise = std::make_shared<std::promise<R>>();
detach_task(
[task = std::forward<F>(task), task_promise]
{
#ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
try
{
#endif
if constexpr (std::is_void_v<R>)
{
task();
task_promise->set_value();
}
else
{
task_promise->set_value(task());
}
#ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING
}
catch (...)
{
try
{
task_promise->set_exception(std::current_exception());
}
catch (...)
{
}
}
#endif
} BS_THREAD_POOL_PRIORITY_OUTPUT);
return task_promise->get_future();
}
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The block function takes two arguments, the start and end of the block, so that it is only called only once per block, but it is up to the user make sure the block function correctly deals with all the indices in each block. Returns a `multi_future` that contains the futures for all of the blocks.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @tparam R The return type of the function to loop through (can be `void`).
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted, and an empty `multi_future` will be returned.
* @param block A function that will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. `block(start, end)` should typically involve a loop of the form `for (T i = start; i < end; ++i)`.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A `multi_future` that can be used to wait for all the blocks to finish. If the block function returns a value, the `multi_future` can also be used to obtain the values returned by each block.
*/
template <typename T, typename F, typename R = std::invoke_result_t<std::decay_t<F>, T, T>>
[[nodiscard]] multi_future<R> submit_blocks(const T first_index, const T index_after_last, F&& block, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
multi_future<R> future;
future.reserve(blks.get_num_blocks());
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
future.push_back(submit_task(
[block = std::forward<F>(block), start = blks.start(blk), end = blks.end(blk)]
{
return block(start, end);
} BS_THREAD_POOL_PRIORITY_OUTPUT));
return future;
}
return {};
}
/**
* @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The loop function takes one argument, the loop index, so that it is called many times per block. It must have no return value. Returns a `multi_future` that contains the futures for all of the blocks.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function to loop through.
* @param first_index The first index in the loop.
* @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted, and an empty `multi_future` will be returned.
* @param loop The function to loop through. Will be called once per index, many times per block. Should take exactly one argument: the loop index. It cannot have a return value.
* @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A `multi_future` that can be used to wait for all the blocks to finish.
*/
template <typename T, typename F>
[[nodiscard]] multi_future<void> submit_loop(const T first_index, const T index_after_last, F&& loop, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count);
multi_future<void> future;
future.reserve(blks.get_num_blocks());
for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk)
future.push_back(submit_task(
[loop = std::forward<F>(loop), start = blks.start(blk), end = blks.end(blk)]
{
for (T i = start; i < end; ++i)
loop(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT));
return future;
}
return {};
}
/**
* @brief Submit a sequence of tasks enumerated by indices to the queue, with the specified priority. Returns a `multi_future` that contains the futures for all of the tasks.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
* @tparam F The type of the function used to define the sequence.
* @tparam R The return type of the function used to define the sequence (can be `void`).
* @param first_index The first index in the sequence.
* @param index_after_last The index after the last index in the sequence. The sequence will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted, and an empty `multi_future` will be returned.
* @param sequence The function used to define the sequence. Will be called once per index. Should take exactly one argument, the index.
* @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined.
* @return A `multi_future` that can be used to wait for all the tasks to finish. If the sequence function returns a value, the `multi_future` can also be used to obtain the values returned by each task.
*/
template <typename T, typename F, typename R = std::invoke_result_t<std::decay_t<F>, T>>
[[nodiscard]] multi_future<R> submit_sequence(const T first_index, const T index_after_last, F&& sequence BS_THREAD_POOL_PRIORITY_INPUT)
{
if (index_after_last > first_index)
{
multi_future<R> future;
future.reserve(static_cast<size_t>(index_after_last - first_index));
for (T i = first_index; i < index_after_last; ++i)
future.push_back(submit_task(
[sequence = std::forward<F>(sequence), i]
{
return sequence(i);
} BS_THREAD_POOL_PRIORITY_OUTPUT));
return future;
}
return {};
}
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
/**
* @brief Unpause the pool. The workers will resume retrieving new tasks out of the queue. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined.
*/
void unpause()
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
paused = false;
}
task_available_cv.notify_all();
}
#endif
// Macros used internally to enable or disable pausing in the waiting and worker functions.
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
#define BS_THREAD_POOL_PAUSED_OR_EMPTY (paused || tasks.empty())
#else
#define BS_THREAD_POOL_PAUSED_OR_EMPTY tasks.empty()
#endif
/**
* @brief Wait for tasks to be completed. Normally, this function waits for all tasks, both those that are currently running in the threads and those that are still waiting in the queue. However, if the pool is paused, this function only waits for the currently running tasks (otherwise it would wait forever). Note: To wait for just one specific task, use `submit_task()` instead, and call the `wait()` member function of the generated future.
*
* @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined.
*/
void wait()
{
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
if (this_thread::get_pool() == this)
throw wait_deadlock();
#endif
std::unique_lock tasks_lock(tasks_mutex);
waiting = true;
tasks_done_cv.wait(tasks_lock,
[this]
{
return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY;
});
waiting = false;
}
/**
* @brief Wait for tasks to be completed, but stop waiting after the specified duration has passed.
*
* @tparam R An arithmetic type representing the number of ticks to wait.
* @tparam P An `std::ratio` representing the length of each tick in seconds.
* @param duration The amount of time to wait.
* @return `true` if all tasks finished running, `false` if the duration expired but some tasks are still running.
*
* @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined.
*/
template <typename R, typename P>
bool wait_for(const std::chrono::duration<R, P>& duration)
{
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
if (this_thread::get_pool() == this)
throw wait_deadlock();
#endif
std::unique_lock tasks_lock(tasks_mutex);
waiting = true;
const bool status = tasks_done_cv.wait_for(tasks_lock, duration,
[this]
{
return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY;
});
waiting = false;
return status;
}
/**
* @brief Wait for tasks to be completed, but stop waiting after the specified time point has been reached.
*
* @tparam C The type of the clock used to measure time.
* @tparam D An `std::chrono::duration` type used to indicate the time point.
* @param timeout_time The time point at which to stop waiting.
* @return `true` if all tasks finished running, `false` if the time point was reached but some tasks are still running.
*
* @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined.
*/
template <typename C, typename D>
bool wait_until(const std::chrono::time_point<C, D>& timeout_time)
{
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
if (this_thread::get_pool() == this)
throw wait_deadlock();
#endif
std::unique_lock tasks_lock(tasks_mutex);
waiting = true;
const bool status = tasks_done_cv.wait_until(tasks_lock, timeout_time,
[this]
{
return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY;
});
waiting = false;
return status;
}
#ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK
// ==============
// Public classes
// ==============
/**
* @brief An exception that will be thrown by `wait()`, `wait_for()`, and `wait_until()` if the user tries to call them from within a thread of the same pool, which would result in a deadlock.
*/
struct wait_deadlock : public std::runtime_error
{
wait_deadlock() : std::runtime_error("BS::thread_pool::wait_deadlock"){};
};
#endif
private:
// ========================
// Private member functions
// ========================
/**
* @brief Create the threads in the pool and assign a worker to each thread.
*
* @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks.
*/
void create_threads(const std::function<void()>& init_task)
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
tasks_running = thread_count;
workers_running = true;
}
for (concurrency_t i = 0; i < thread_count; ++i)
{
threads[i] = std::thread(&thread_pool::worker, this, i, init_task);
}
}
/**
* @brief Destroy the threads in the pool.
*/
void destroy_threads()
{
{
const std::scoped_lock tasks_lock(tasks_mutex);
workers_running = false;
}
task_available_cv.notify_all();
for (concurrency_t i = 0; i < thread_count; ++i)
{
threads[i].join();
}
}
/**
* @brief Determine how many threads the pool should have, based on the parameter passed to the constructor or reset().
*
* @param num_threads The parameter passed to the constructor or `reset()`. If the parameter is a positive number, then the pool will be created with this number of threads. If the parameter is non-positive, or a parameter was not supplied (in which case it will have the default value of 0), then the pool will be created with the total number of hardware threads available, as obtained from `std::thread::hardware_concurrency()`. If the latter returns zero for some reason, then the pool will be created with just one thread.
* @return The number of threads to use for constructing the pool.
*/
[[nodiscard]] static concurrency_t determine_thread_count(const concurrency_t num_threads)
{
if (num_threads > 0)
return num_threads;
if (std::thread::hardware_concurrency() > 0)
return std::thread::hardware_concurrency();
return 1;
}
/**
* @brief A worker function to be assigned to each thread in the pool. Waits until it is notified by `detach_task()` that a task is available, and then retrieves the task from the queue and executes it. Once the task finishes, the worker notifies `wait()` in case it is waiting.
*
* @param idx The index of this thread.
* @param init_task An initialization function to run in this thread before it starts to execute any submitted tasks.
*/
void worker(const concurrency_t idx, const std::function<void()>& init_task)
{
this_thread::get_index.index = idx;
this_thread::get_pool.pool = this;
init_task();
std::unique_lock tasks_lock(tasks_mutex);
while (true)
{
--tasks_running;
tasks_lock.unlock();
if (waiting && (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY)
tasks_done_cv.notify_all();
tasks_lock.lock();
task_available_cv.wait(tasks_lock,
[this]
{
return !BS_THREAD_POOL_PAUSED_OR_EMPTY || !workers_running;
});
if (!workers_running)
break;
{
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
const std::function<void()> task = std::move(std::remove_const_t<pr_task&>(tasks.top()).task);
tasks.pop();
#else
const std::function<void()> task = std::move(tasks.front());
tasks.pop();
#endif
++tasks_running;
tasks_lock.unlock();
task();
}
tasks_lock.lock();
}
this_thread::get_index.index = std::nullopt;
this_thread::get_pool.pool = std::nullopt;
}
// ===============
// Private classes
// ===============
/**
* @brief A helper class to divide a range into blocks. Used by `detach_blocks()`, `submit_blocks()`, `detach_loop()`, and `submit_loop()`.
*
* @tparam T The type of the indices. Should be a signed or unsigned integer.
*/
template <typename T>
class [[nodiscard]] blocks
{
public:
/**
* @brief Construct a `blocks` object with the given specifications.
*
* @param first_index_ The first index in the range.
* @param index_after_last_ The index after the last index in the range.
* @param num_blocks_ The desired number of blocks to divide the range into.
*/
blocks(const T first_index_, const T index_after_last_, const size_t num_blocks_) : first_index(first_index_), index_after_last(index_after_last_), num_blocks(num_blocks_)
{
if (index_after_last > first_index)
{
const size_t total_size = static_cast<size_t>(index_after_last - first_index);
if (num_blocks > total_size)
num_blocks = total_size;
block_size = total_size / num_blocks;
remainder = total_size % num_blocks;
if (block_size == 0)
{
block_size = 1;
num_blocks = (total_size > 1) ? total_size : 1;
}
}
else
{
num_blocks = 0;
}
}
/**
* @brief Get the first index of a block.
*
* @param block The block number.
* @return The first index.
*/
[[nodiscard]] T start(const size_t block) const
{
return first_index + static_cast<T>(block * block_size) + static_cast<T>(block < remainder ? block : remainder);
}
/**
* @brief Get the index after the last index of a block.
*
* @param block The block number.
* @return The index after the last index.
*/
[[nodiscard]] T end(const size_t block) const
{
return (block == num_blocks - 1) ? index_after_last : start(block + 1);
}
/**
* @brief Get the number of blocks. Note that this may be different than the desired number of blocks that was passed to the constructor.
*
* @return The number of blocks.
*/
[[nodiscard]] size_t get_num_blocks() const
{
return num_blocks;
}
private:
/**
* @brief The size of each block (except possibly the last block).
*/
size_t block_size = 0;
/**
* @brief The first index in the range.
*/
T first_index = 0;
/**
* @brief The index after the last index in the range.
*/
T index_after_last = 0;
/**
* @brief The number of blocks.
*/
size_t num_blocks = 0;
/**
* @brief The remainder obtained after dividing the total size by the number of blocks.
*/
size_t remainder = 0;
}; // class blocks
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
/**
* @brief A helper class to store a task with an assigned priority.
*/
class [[nodiscard]] pr_task
{
friend class thread_pool;
public:
/**
* @brief Construct a new task with an assigned priority by copying the task.
*
* @param task_ The task.
* @param priority_ The desired priority.
*/
explicit pr_task(const std::function<void()>& task_, const priority_t priority_ = 0) : task(task_), priority(priority_) {}
/**
* @brief Construct a new task with an assigned priority by moving the task.
*
* @param task_ The task.
* @param priority_ The desired priority.
*/
explicit pr_task(std::function<void()>&& task_, const priority_t priority_ = 0) : task(std::move(task_)), priority(priority_) {}
/**
* @brief Compare the priority of two tasks.
*
* @param lhs The first task.
* @param rhs The second task.
* @return `true` if the first task has a lower priority than the second task, `false` otherwise.
*/
[[nodiscard]] friend bool operator<(const pr_task& lhs, const pr_task& rhs)
{
return lhs.priority < rhs.priority;
}
private:
/**
* @brief The task.
*/
std::function<void()> task = {};
/**
* @brief The priority of the task.
*/
priority_t priority = 0;
}; // class pr_task
#endif
// ============
// Private data
// ============
#ifdef BS_THREAD_POOL_ENABLE_PAUSE
/**
* @brief A flag indicating whether the workers should pause. When set to `true`, the workers temporarily stop retrieving new tasks out of the queue, although any tasks already executed will keep running until they are finished. When set to `false` again, the workers resume retrieving tasks.
*/
bool paused = false;
#endif
/**
* @brief A condition variable to notify `worker()` that a new task has become available.
*/
std::condition_variable task_available_cv = {};
/**
* @brief A condition variable to notify `wait()` that the tasks are done.
*/
std::condition_variable tasks_done_cv = {};
/**
* @brief A queue of tasks to be executed by the threads.
*/
#ifdef BS_THREAD_POOL_ENABLE_PRIORITY
std::priority_queue<pr_task> tasks = {};
#else
std::queue<std::function<void()>> tasks = {};
#endif
/**
* @brief A counter for the total number of currently running tasks.
*/
size_t tasks_running = 0;
/**
* @brief A mutex to synchronize access to the task queue by different threads.
*/
mutable std::mutex tasks_mutex = {};
/**
* @brief The number of threads in the pool.
*/
concurrency_t thread_count = 0;
/**
* @brief A smart pointer to manage the memory allocated for the threads.
*/
std::unique_ptr<std::thread[]> threads = nullptr;
/**
* @brief A flag indicating that `wait()` is active and expects to be notified whenever a task is done.
*/
bool waiting = false;
/**
* @brief A flag indicating to the workers to keep running. When set to `false`, the workers terminate permanently.
*/
bool workers_running = false;
}; // class thread_pool
} // namespace BS
#endif
\ No newline at end of file
......@@ -3,7 +3,9 @@
// clang-format off
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <numeric>
#include <stdexcept>
#include <string>
......@@ -21,11 +23,14 @@
#include <ankerl/unordered_dense.h>
#include <unordered_set>
#include "BS_thread_pool.h"
#include "ygoenv/core/async_envpool.h"
#include "ygoenv/core/env.h"
#include "ygopro-core/common.h"
#include "ygopro-core/card_data.h"
#include "ygopro-core/duel.h"
#include "ygopro-core/ocgapi.h"
// clang-format on
......@@ -892,8 +897,18 @@ public:
}
};
// TODO: 7% performance loss
static std::shared_timed_mutex duel_mtx;
struct MDuel {
intptr_t pduel;
uint64_t seed;
std::vector<CardCode> main_deck0;
std::vector<CardCode> extra_deck0;
std::string deck_name0;
std::vector<CardCode> main_deck1;
std::vector<CardCode> extra_deck1;
std::string deck_name1;
};
static std::mutex duel_mtx;
inline Card db_query_card(const SQLite::Database &db, CardCode code) {
SQLite::Statement query1(db, "SELECT * FROM datas WHERE id=?");
......@@ -1237,7 +1252,7 @@ public:
"play_mode"_.Bind(std::string("bot")),
"verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false));
"record"_.Bind(false), "async_reset"_.Bind(true));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
......@@ -1301,6 +1316,10 @@ constexpr int32_t duel_options_ = ((rules_ & 0xFF) << 16) + (0 & 0xFFFF);
class YGOProEnv : public Env<YGOProEnvSpec> {
protected:
constexpr static int init_lp_ = 8000;
constexpr static int startcount_ = 5;
constexpr static int drawcount_ = 1;
std::string deck1_;
std::string deck2_;
std::vector<uint32> main_deck0_;
......@@ -1324,7 +1343,7 @@ protected:
PlayerId ai_player_;
intptr_t pduel_;
intptr_t pduel_ = 0;
Player *players_[2]; // abstract class must be pointer
std::uniform_int_distribution<uint64_t> dist_int_;
......@@ -1365,6 +1384,9 @@ protected:
uint64_t step_time_count_ = 0;
double reset_time_ = 0;
double reset_time_1_ = 0;
double reset_time_2_ = 0;
double reset_time_3_ = 0;
uint64_t reset_time_count_ = 0;
const int n_history_actions_;
......@@ -1403,6 +1425,14 @@ protected:
// MSG_SELECT_COUNTER
int n_counters_ = 0;
// async reset
const bool async_reset_;
int n_lives_ = 0;
std::future<MDuel> duel_fut_;
BS::thread_pool pool_;
std::mt19937 duel_gen_;
public:
YGOProEnv(const Spec &spec, int env_id)
: Env<YGOProEnvSpec>(spec, env_id),
......@@ -1412,13 +1442,21 @@ public:
player_(spec.config["player"_]), players_{nullptr, nullptr},
play_modes_(parse_play_modes(spec.config["play_mode"_])),
verbose_(spec.config["verbose"_]), record_(spec.config["record"_]),
n_history_actions_(spec.config["n_history_actions"_]) {
n_history_actions_(spec.config["n_history_actions"_]), pool_(BS::thread_pool(1)),
async_reset_(spec.config["async_reset"_]) {
if (record_) {
if (!verbose_) {
throw std::runtime_error("record mode must be used with verbose mode and num_envs=1");
}
// replay_data_ = new uint8_t[MAX_REPLAY_SIZE];
// rdata_ = replay_data_;
}
duel_gen_ = std::mt19937(dist_int_(gen_));
if (async_reset_) {
duel_fut_ = pool_.submit_task([
this, duel_seed=dist_int_(gen_)] {
return new_duel(duel_seed);
});
}
int max_options = spec.config["max_options"_];
......@@ -1452,8 +1490,36 @@ public:
play_modes_.end();
}
void update_time_stat(const clock_t& start, uint64_t time_count, double& time_stat) {
double seconds = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
time_stat = time_stat * (static_cast<double>(time_count) /
(time_count + 1)) + seconds / (time_count + 1);
}
MDuel new_duel(uint32_t seed) {
auto pduel = YGO_CreateDuel(seed);
MDuel mduel{pduel, seed};
for (PlayerId i = 0; i < 2; i++) {
YGO_SetPlayerInfo(pduel, i, init_lp_, startcount_, drawcount_);
auto [main_deck, extra_deck, deck_name] = load_deck(pduel, i, duel_gen_);
if (i == 0) {
mduel.main_deck0 = main_deck;
mduel.extra_deck0 = extra_deck;
mduel.deck_name0 = deck_name;
} else {
mduel.main_deck1 = main_deck;
mduel.extra_deck1 = extra_deck;
mduel.deck_name1 = deck_name;
}
}
YGO_StartDuel(pduel, duel_options_);
return mduel;
}
void Reset() override {
// clock_t start = clock();
clock_t start = clock();
if (random_mode()) {
play_mode_ = play_modes_[dist_int_(gen_) % play_modes_.size()];
} else {
......@@ -1476,15 +1542,26 @@ public:
ha_p_0_ = 0;
ha_p_1_ = 0;
auto duel_seed = dist_int_(gen_);
clock_t _start = clock();
std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx);
YGO_CreateDuel(duel_seed);
ulock.unlock();
intptr_t old_duel = pduel_;
MDuel mduel;
if (async_reset_) {
mduel = duel_fut_.get();
n_lives_ = 1;
} else {
mduel = new_duel(dist_int_(gen_));
}
int init_lp = 8000;
int startcount = 5;
int drawcount = 1;
auto duel_seed = mduel.seed;
pduel_ = mduel.pduel;
deck_name_[0] = mduel.deck_name0;
deck_name_[1] = mduel.deck_name1;
main_deck0_ = mduel.main_deck0;
extra_deck0_ = mduel.extra_deck0;
main_deck1_ = mduel.main_deck1;
extra_deck1_ = mduel.extra_deck1;
for (PlayerId i = 0; i < 2; i++) {
if (players_[i] != nullptr) {
......@@ -1496,15 +1573,13 @@ public:
}
nickname_[i] = nickname;
if ((play_mode_ == kHuman) && (i != ai_player_)) {
players_[i] = new HumanPlayer(nickname_[i], init_lp, i, verbose_);
players_[i] = new HumanPlayer(nickname_[i], init_lp_, i, verbose_);
} else if (play_mode_ == kRandomBot) {
players_[i] = new RandomAI(max_options(), dist_int_(gen_), nickname_[i],
init_lp, i, verbose_);
init_lp_, i, verbose_);
} else {
players_[i] = new GreedyAI(nickname_[i], init_lp, i, verbose_);
players_[i] = new GreedyAI(nickname_[i], init_lp_, i, verbose_);
}
YGO_SetPlayerInfo(pduel_, i, init_lp, startcount, drawcount);
load_deck(i);
lp_[i] = players_[i]->init_lp_;
}
......@@ -1553,9 +1628,9 @@ public:
fwrite(name, 40, 1, fp_);
}
ReplayWriteInt32(init_lp);
ReplayWriteInt32(startcount);
ReplayWriteInt32(drawcount);
ReplayWriteInt32(init_lp_);
ReplayWriteInt32(startcount_);
ReplayWriteInt32(drawcount_);
ReplayWriteInt32(duel_options_);
for (PlayerId i = 0; i < 2; i++) {
......@@ -1573,24 +1648,34 @@ public:
}
YGO_StartDuel(pduel_, duel_options_);
duel_started_ = true;
winner_ = 255;
win_reason_ = 255;
// update_time_stat(_start, reset_time_count_, reset_time_2_);
// _start = clock();
next();
done_ = false;
elapsed_step_ = 0;
WriteState(0.0);
// double seconds = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
// // update reset_time by moving average
// reset_time_ = reset_time_* (static_cast<double>(reset_time_count_) /
// (reset_time_count_ + 1)) + seconds / (reset_time_count_ + 1);
if (async_reset_) {
duel_fut_ = pool_.submit_task([
this, old_duel, duel_seed=dist_int_(gen_)] {
if (old_duel != 0) {
YGO_EndDuel(old_duel);
}
return new_duel(duel_seed);
});
}
// update_time_stat(_start, reset_time_count_, reset_time_3_);
// update_time_stat(start, reset_time_count_, reset_time_);
// reset_time_count_++;
// if (reset_time_count_ % 20 == 0) {
// fmt::println("Reset time: {:.3f}", reset_time_);
// fmt::println("Reset time: {:.3f}, {:.3f}, {:.3f}", reset_time_ * 1000, reset_time_2_ * 1000, reset_time_3_ * 1000);
// }
}
......@@ -1617,7 +1702,7 @@ public:
options_.push_back(spec);
}
} else {
ms_combs_ = combs;
ms_combs_ = combs;
_callback_multi_select_2_prepare();
}
}
......@@ -1854,13 +1939,10 @@ public:
WriteState(reward, win_reason_);
// double seconds = static_cast<double>(clock() - start) / CLOCKS_PER_SEC;
// // update step_time by moving average
// step_time_ = step_time_* (static_cast<double>(step_time_count_) /
// (step_time_count_ + 1)) + seconds / (step_time_count_ + 1);
// update_time_stat(start, step_time_count_, step_time_);
// step_time_count_++;
// if (step_time_count_ % 500 == 0) {
// fmt::println("Step time: {:.3f}", step_time_);
// if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000);
// }
}
......@@ -2210,25 +2292,30 @@ private:
}
// ygopro-core API
void YGO_CreateDuel(uint32_t seed) {
intptr_t YGO_CreateDuel(uint32_t seed) {
std::mt19937 rnd(seed);
pduel_ = create_duel(rnd());
// return create_duel(rnd());
duel* pduel = new duel();
pduel->random.reset(rnd());
return (intptr_t)pduel;
}
void YGO_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) {
void YGO_SetPlayerInfo(intptr_t pduel, int32 playerid, int32 lp, int32 startcount, int32 drawcount) const {
set_player_info(pduel, playerid, lp, startcount, drawcount);
}
void YGO_NewCard(intptr_t pduel, uint32 code, uint8 owner, uint8 playerid, uint8 location, uint8 sequence, uint8 position) {
void YGO_NewCard(intptr_t pduel, uint32 code, uint8 owner, uint8 playerid, uint8 location, uint8 sequence, uint8 position) const {
new_card(pduel, code, owner, playerid, location, sequence, position);
}
void YGO_StartDuel(intptr_t pduel, int32 options) {
void YGO_StartDuel(intptr_t pduel, int32 options) const {
start_duel(pduel, options);
}
void YGO_EndDuel(intptr_t pduel) {
end_duel(pduel);
void YGO_EndDuel(intptr_t pduel) const {
// end_duel(pduel);
duel* pd = (duel*)pduel;
delete pd;
}
int32 YGO_GetMessage(intptr_t pduel, byte* buf) {
......@@ -2355,47 +2442,45 @@ private:
options_);
}
void load_deck(PlayerId player, bool shuffle = true) {
std::string deck = player == 0 ? deck1_ : deck2_;
std::vector<CardCode> &main_deck = player == 0 ? main_deck0_ : main_deck1_;
std::vector<CardCode> &extra_deck =
player == 0 ? extra_deck0_ : extra_deck1_;
std::tuple<std::vector<CardCode>, std::vector<CardCode>, std::string>
load_deck(
intptr_t pduel, PlayerId player, std::mt19937& gen, bool shuffle = true) const {
std::string deck_name = player == 0 ? deck1_ : deck2_;
if (deck == "random") {
if (deck_name == "random") {
// generate random deck name
std::uniform_int_distribution<uint64_t> dist_int(0,
deck_names_.size() - 1);
deck_name_[player] = deck_names_[dist_int(gen_)];
} else {
deck_name_[player] = deck;
deck_name = deck_names_[dist_int(gen)];
}
deck = deck_name_[player];
main_deck = main_decks_.at(deck);
extra_deck = extra_decks_.at(deck);
std::vector<CardCode> main_deck = main_decks_.at(deck_name);
std::vector<CardCode> extra_deck = extra_decks_.at(deck_name);
if (verbose_) {
fmt::println("{} {}: {}, main({}), extra({})", player, nickname_[player],
deck, main_deck.size(), extra_deck.size());
deck_name, main_deck.size(), extra_deck.size());
}
if (shuffle) {
std::shuffle(main_deck.begin(), main_deck.end(), gen_);
std::shuffle(main_deck.begin(), main_deck.end(), gen);
}
// add main deck in reverse order following ygopro
// but since we have shuffled deck, so just add in order
for (int i = 0; i < main_deck.size(); i++) {
YGO_NewCard(pduel_, main_deck[i], player, player, LOCATION_DECK, 0,
YGO_NewCard(pduel, main_deck[i], player, player, LOCATION_DECK, 0,
POS_FACEDOWN_DEFENSE);
}
// add extra deck in reverse order following ygopro
for (int i = int(extra_deck.size()) - 1; i >= 0; --i) {
YGO_NewCard(pduel_, extra_deck[i], player, player, LOCATION_EXTRA, 0,
YGO_NewCard(pduel, extra_deck[i], player, player, LOCATION_EXTRA, 0,
POS_FACEDOWN_DEFENSE);
}
return {main_deck, extra_deck, deck_name};
}
void next() {
......@@ -4573,10 +4658,11 @@ private:
void _duel_end(uint8_t player, uint8_t reason) {
winner_ = player;
win_reason_ = reason;
std::unique_lock<std::shared_timed_mutex> ulock(duel_mtx);
YGO_EndDuel(pduel_);
ulock.unlock();
if (async_reset_) {
n_lives_--;
} else {
YGO_EndDuel(pduel_);
}
duel_started_ = false;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment