Commit 93bc3723 authored by sbl1996@126.com

Add option for greedy_reward and correct upgo

parent 0ecf0a00
@@ -26,7 +26,7 @@ from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -62,6 +62,8 @@ class Args:
"""the maximum number of options"""
n_history_actions: int = 32
"""the number of history actions to use"""
greedy_reward: bool = True
"""whether to use greedy reward (a faster kill yields a higher reward)"""
total_timesteps: int = 5000000000
"""total timesteps of the experiments"""
@@ -117,7 +119,7 @@ class Args:
"""whether to use `jax.distributed`"""
concurrency: bool = True
"""whether to run the actor and learner concurrently"""
bfloat16: bool = True
bfloat16: bool = False
"""whether to use bfloat16 for the agent"""
thread_affinity: bool = False
"""whether to use thread affinity for the environment"""
@@ -161,6 +163,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
max_options=args.max_options,
n_history_actions=args.n_history_actions,
async_reset=False,
greedy_reward=args.greedy_reward,
play_mode=mode,
)
envs.num_envs = num_envs
@@ -596,10 +599,12 @@ if __name__ == "__main__":
(jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
)
compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
advantages, target_values = compute_gae_fn(
advantages, target_values = compute_gae_2p0s(
next_value, values, rewards, next_dones, switch,
args.gamma, args.gae_lambda)
if args.upgo:
advantages = advantages + upgo_advantage(
next_value, values, rewards, next_dones, switch, args.gamma)
advantages, target_values = jax.tree.map(
lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
......
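For context, a minimal self-contained sketch of the corrected advantage path in the hunk above, run on dummy tensors. The `(num_steps, num_envs)` shapes and the `gamma`/`gae_lambda` values are illustrative assumptions, not values from the repository; the import mirrors the one added at the top of the training script.

```python
import jax.numpy as jnp
from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage

# Dummy rollout tensors; per-step arrays are assumed to be (num_steps, num_envs).
num_steps, num_envs = 128, 16
values = jnp.zeros((num_steps, num_envs))
rewards = jnp.zeros((num_steps, num_envs))
next_dones = jnp.zeros((num_steps, num_envs), dtype=jnp.bool_)
switch = jnp.zeros((num_steps, num_envs), dtype=jnp.bool_)
next_value = jnp.zeros((num_envs,))   # bootstrap value for the final state
gamma, gae_lambda = 0.997, 0.95       # placeholder hyperparameters

# GAE always supplies the advantages and the value targets ...
advantages, target_values = compute_gae_2p0s(
    next_value, values, rewards, next_dones, switch, gamma, gae_lambda)

# ... and with --upgo the UPGO term is added on top of the GAE advantages,
# rather than replacing the whole GAE computation as the old compute_gae_upgo_2p0s did.
advantages = advantages + upgo_advantage(
    next_value, values, rewards, next_dones, switch, gamma)
```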
@@ -67,27 +67,6 @@ def vtrace(
return VTraceOutput(q_estimate=q_estimate, errors=errors)
def upgo_return(r_t, v_t, discount_t, stop_target_gradients: bool = True):
def _body(acc, xs):
r, v, q, discount = xs
acc = r + discount * jnp.where(q >= v, acc, v)
return acc, acc
# TODO: following alphastar, estimate q_t with one-step target
# It might be better to use network to estimate q_t
q_t = r_t[1:] + discount_t[1:] * v_t[1:] # q[:-1]
_, returns = jax.lax.scan(
_body, q_t[-1], (r_t[:-1], v_t[:-1], q_t, discount_t[:-1]), reverse=True)
# Following rlax.vtrace_td_error_and_advantage, part of gradient is reserved
# Experiments show that where to stop gradient has no impact on the performance
returns = jax.lax.select(
stop_target_gradients, jax.lax.stop_gradient(returns), returns)
returns = jnp.concatenate([returns, q_t[-1:]], axis=0)
return returns
def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
clipped_ratios_t = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
@@ -123,39 +102,112 @@ def compute_gae_2p0s(
return advantages, target_values
@partial(jax.jit, static_argnums=(5, 6))
def compute_gae_upgo_2p0s(
next_value, values, rewards, next_dones, switch,
gamma, gae_lambda,
):
@partial(jax.jit, static_argnums=(5,))
def upgo_advantage(
next_value, values, rewards, next_dones, switch, gamma):
def body_fn(carry, inp):
boot_value, boot_done, next_value, next_q, last_return, lastgaelam = carry
boot_value, boot_done, next_value, next_q, last_return = carry
next_done, cur_value, reward, switch = inp
next_done = jnp.where(switch, boot_done, next_done)
next_value = jnp.where(switch, -boot_value, next_value)
next_q = jnp.where(switch, -boot_value * gamma, next_q)
last_return = jnp.where(switch, -boot_value, last_return)
lastgaelam = jnp.where(switch, 0, lastgaelam)
gamma_ = gamma * (1.0 - next_done)
last_return = reward + gamma_ * jnp.where(
next_q >= next_value, last_return, next_value)
next_q = reward + gamma_ * next_value
delta = next_q - cur_value
lastgaelam = delta + gae_lambda * gamma_ * lastgaelam
carry = boot_value, boot_done, cur_value, next_q, last_return, lastgaelam
return carry, (lastgaelam, last_return)
carry = boot_value, boot_done, cur_value, next_q, last_return
return carry, last_return
next_done = next_dones[-1]
lastgaelam = jnp.zeros_like(next_value)
carry = next_value, next_done, next_value, next_value, next_value, lastgaelam
carry = next_value, next_done, next_value, next_value, next_value
_, (advantages, returns) = jax.lax.scan(
_, returns = jax.lax.scan(
body_fn, carry, (next_dones, values, rewards, switch), reverse=True
)
return returns - values, advantages + values
return returns - values
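For reference, the recursion this scan computes, written as a plain single-trajectory loop with the two-player `switch` and sign handling stripped out; the names and the NumPy framing are illustrative only.

```python
import numpy as np

def upgo_returns_reference(rewards, values, next_value, next_dones, gamma):
    # rewards, values, next_dones: shape (T,); next_value: bootstrap value V(s_T).
    # UPGO return: G_t = r_t + gamma*(1 - d_{t+1}) * (G_{t+1} if Q_{t+1} >= V_{t+1} else V_{t+1}),
    # with the one-step estimate Q_t = r_t + gamma*(1 - d_{t+1}) * V_{t+1}.
    T = len(rewards)
    returns = np.zeros(T)
    g = q_next = v_next = next_value
    for t in reversed(range(T)):
        nonterminal = 1.0 - next_dones[t]
        g = rewards[t] + gamma * nonterminal * (g if q_next >= v_next else v_next)
        returns[t] = g
        q_next = rewards[t] + gamma * nonterminal * v_next
        v_next = values[t]
    return returns

# upgo_advantage returns `returns - values`, which the training script adds to the GAE advantages.
```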
# def compute_gae_once(carry, inp, gamma, gae_lambda):
# v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2 = carry
# rho, cur_values, log_ratio, next_done, r_t, corr_r_t, main = inp
# v = jnp.where(main, v1, v2)
# next_values = jnp.where(main, next_values1, next_values2)
# reward = jnp.where(main, reward1, reward2)
# xi = jnp.where(main, xi1, xi2)
# p_t = c_t = jnp.minimum(1.0, rho * xi)
# sig_v = p_t * (r_t + reward * rho + next_values - cur_values)
# reg_r = jnp.log(p / p_reg)
# q = r_t + rho * (reward + v)
# q = -eta * + cur_values
# v = cur_values + sig_v + c_t * (v - next_values)
# v1 = jnp.where(main, v, v1)
# v2 = jnp.where(main, v2, v)
# next_values1 = jnp.where(main, cur_values, next_values1)
# next_values2 = jnp.where(main, next_values2, cur_values)
# reward1 = jnp.where(main, 0, r_t + rho * reward1)
# reward2 = jnp.where(main, r_t + rho * reward2, 0)
# xi1 = jnp.where(main, 1, rho * xi1)
# xi2 = jnp.where(main, rho * xi2, 1)
# learn1 = learn
# learn2 = ~learn
# factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
# reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
# reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
# real_done1 = next_done | ~done_used1
# nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
# lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
# real_done2 = next_done | ~done_used2
# nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
# lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
# done_used1 = jnp.where(
# next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
# done_used2 = jnp.where(
# next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
# delta1 = reward1 + gamma * nextvalues1 - curvalues
# delta2 = reward2 + gamma * nextvalues2 - curvalues
# lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
# lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
# advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
# nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
# nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
# lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
# lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
# carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# return carry, advantages
# @partial(jax.jit, static_argnums=(6, 7))
# def vtrace_rnad(
# next_value, next_done, values, rewards, dones, learns,
# gamma, gae_lambda,
# ):
# next_value1 = next_value
# next_value2 = -next_value1
# done_used1 = jnp.ones_like(next_done)
# done_used2 = jnp.ones_like(next_done)
# reward1 = jnp.zeros_like(next_value)
# reward2 = jnp.zeros_like(next_value)
# lastgaelam1 = jnp.zeros_like(next_value)
# lastgaelam2 = jnp.zeros_like(next_value)
# carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
# _, advantages = jax.lax.scan(
# partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
# carry, (dones[1:], values, rewards, learns), reverse=True
# )
# target_values = advantages + values
# return advantages, target_values
def compute_gae_once(carry, inp, gamma, gae_lambda):
......
@@ -336,16 +336,17 @@ class PPOLSTMAgent(nn.Module):
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
multi_step: bool = False
switch: bool = True
@nn.compact
def __call__(self, inputs):
if self.multi_step:
# (num_steps * batch_size, ...)
carry1, carry2, x, done, switch = inputs
batch_size = carry1[0].shape[0]
rstate1, rstate2, x, done, switch_or_main = inputs
batch_size = rstate1[0].shape[0]
num_steps = done.shape[0] // batch_size
else:
carry, x = inputs
rstate, x = inputs
c = self.channels
encoder = Encoder(
@@ -361,21 +362,31 @@
lstm_layer = nn.OptimizedLSTMCell(
self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=nn.initializers.orthogonal(1.0))
if self.multi_step:
def body_fn(cell, carry, x, done, switch):
carry, init_carry = carry
carry, y = cell(carry, x)
carry = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), carry)
carry = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_carry, carry)
return (carry, init_carry), y
if self.switch:
def body_fn(cell, carry, x, done, switch):
rstate, init_rstate2 = carry
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
return (rstate, init_rstate2), y
else:
def body_fn(cell, carry, x, done, main):
rstate1, rstate2 = carry
rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
rstate, y = cell(rstate, x)
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
return (rstate1, rstate2), y
scan = nn.scan(
body_fn, variable_broadcast='params',
split_rngs={'params': False})
f_state, done, switch = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch))
carry, f_state = scan(lstm_layer, (carry1, carry2), f_state, done, switch)
f_state, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
rstate, f_state = scan(lstm_layer, (rstate1, rstate2), f_state, done, switch_or_main)
f_state = f_state.reshape((-1, f_state.shape[-1]))
else:
carry, f_state = lstm_layer(carry, f_state)
rstate, f_state = lstm_layer(rstate, f_state)
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
@@ -384,4 +395,4 @@ class PPOLSTMAgent(nn.Module):
logits = actor(f_state, f_actions, mask)
value = critic(f_state)
return carry, logits, value, valid
return rstate, logits, value, valid
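To make the two `body_fn` variants easier to compare, here is a stripped-down sketch of their recurrent-state bookkeeping with a toy cell standing in for `nn.OptimizedLSTMCell`; the comments describe my reading of the `switch`/`main` flags from this diff, not documented behaviour.

```python
import jax
import jax.numpy as jnp

def toy_cell(rstate, x):
    # Stand-in for nn.OptimizedLSTMCell: (c, h) state, returns (new_state, output).
    c, h = rstate
    h = jnp.tanh(x + h)
    return (c, h), h

def step_switch(carry, x, done, switch):
    # switch=True variant: advance the current state; zero it on done; when the
    # per-env `switch` flag fires, jump to the stored second state (the other
    # player's state captured at the start of the chunk), which never changes.
    rstate, init_rstate2 = carry
    rstate, y = toy_cell(rstate, x)
    rstate = jax.tree.map(lambda s: jnp.where(done[:, None], 0, s), rstate)
    rstate = jax.tree.map(lambda init, cur: jnp.where(switch[:, None], init, cur),
                          init_rstate2, rstate)
    return (rstate, init_rstate2), y

def step_main(carry, x, done, main):
    # switch=False variant: carry both players' states, advance whichever one the
    # `main` flag selects, and write only that slot back.
    rstate1, rstate2 = carry
    rstate = jax.tree.map(lambda a, b: jnp.where(main[:, None], a, b), rstate1, rstate2)
    rstate, y = toy_cell(rstate, x)
    rstate = jax.tree.map(lambda s: jnp.where(done[:, None], 0, s), rstate)
    rstate1 = jax.tree.map(lambda new, old: jnp.where(main[:, None], new, old), rstate, rstate1)
    rstate2 = jax.tree.map(lambda new, old: jnp.where(main[:, None], old, new), rstate, rstate2)
    return (rstate1, rstate2), y

# Example: scan one of the step functions over a dummy sequence.
T, B, C = 6, 4, 8
state0 = (jnp.zeros((B, C)), jnp.zeros((B, C)))
xs = (jnp.zeros((T, B, C)), jnp.zeros((T, B), bool), jnp.zeros((T, B), bool))
(_, _), ys = jax.lax.scan(lambda c, i: step_switch(c, *i), (state0, state0), xs)
```

As far as I can tell from the hunk, the `switch` layout assumes each player's steps are contiguous within a chunk (the state jumps once to the stored second state at the switch point), whereas the `main` layout lets the two players' steps interleave, at the cost of carrying both recurrent states.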
@@ -1252,7 +1252,7 @@ public:
"play_mode"_.Bind(std::string("bot")),
"verbose"_.Bind(false), "max_options"_.Bind(16),
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(true));
"record"_.Bind(false), "async_reset"_.Bind(true), "greedy_reward"_.Bind(true));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
@@ -1353,6 +1353,7 @@ protected:
PlayerId winner_;
uint8_t win_reason_;
bool greedy_reward_;
int lp_[2];
@@ -1438,7 +1439,7 @@ public:
play_modes_(parse_play_modes(spec.config["play_mode"_])),
verbose_(spec.config["verbose"_]), record_(spec.config["record"_]),
n_history_actions_(spec.config["n_history_actions"_]), pool_(BS::thread_pool(1)),
async_reset_(spec.config["async_reset"_]) {
async_reset_(spec.config["async_reset"_]), greedy_reward_(spec.config["greedy_reward"_]) {
if (record_) {
if (!verbose_) {
throw std::runtime_error("record mode must be used with verbose mode and num_envs=1");
@@ -1879,29 +1880,33 @@ public:
int reason = 0;
if (done_) {
float base_reward;
if (winner_ == 0) {
if (turn_count_ <= 1) {
// FTK
base_reward = 16.0;
} else if (turn_count_ <= 3) {
base_reward = 8.0;
} else if (turn_count_ <= 5) {
base_reward = 4.0;
} else if (turn_count_ <= 7) {
base_reward = 2.0;
if (greedy_reward_) {
if (winner_ == 0) {
if (turn_count_ <= 1) {
// FTK
base_reward = 16.0;
} else if (turn_count_ <= 3) {
base_reward = 8.0;
} else if (turn_count_ <= 5) {
base_reward = 4.0;
} else if (turn_count_ <= 7) {
base_reward = 2.0;
} else {
base_reward = 0.5 + 1.0 / (turn_count_ - 7);
}
} else {
base_reward = 0.5 + 1.0 / (turn_count_ - 7);
if (turn_count_ <= 1) {
base_reward = 8.0;
} else if (turn_count_ <= 3) {
base_reward = 4.0;
} else if (turn_count_ <= 5) {
base_reward = 2.0;
} else {
base_reward = 0.5 + 1.0 / (turn_count_ - 5);
}
}
} else {
if (turn_count_ <= 1) {
base_reward = 8.0;
} else if (turn_count_ <= 3) {
base_reward = 4.0;
} else if (turn_count_ <= 5) {
base_reward = 2.0;
} else {
base_reward = 0.5 + 1.0 / (turn_count_ - 5);
}
base_reward = 1.0;
}
if (play_mode_ == kSelfPlay) {
......
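For reference, a Python transcription of the terminal-reward schedule in the C++ hunk above; the `winner == 0` branch copies the C++ condition verbatim and no further player-index semantics are implied.

```python
def base_reward(winner: int, turn_count: int, greedy_reward: bool = True) -> float:
    # Transcribed from the C++ above: with greedy_reward, the terminal reward
    # magnitude decays with the turn count, so faster kills are worth more;
    # without it, the base reward is a flat 1.0.
    if not greedy_reward:
        return 1.0
    if winner == 0:
        if turn_count <= 1:        # FTK
            return 16.0
        if turn_count <= 3:
            return 8.0
        if turn_count <= 5:
            return 4.0
        if turn_count <= 7:
            return 2.0
        return 0.5 + 1.0 / (turn_count - 7)
    if turn_count <= 1:
        return 8.0
    if turn_count <= 3:
        return 4.0
    if turn_count <= 5:
        return 2.0
    return 0.5 + 1.0 / (turn_count - 5)
```

Since both the Python `Args.greedy_reward` default and the new `"greedy_reward"_.Bind(true)` binding default to true, the previous turn-scaled behaviour is kept unless the option is explicitly disabled.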