Commit 096e743e authored by sbl1996@126.com's avatar sbl1996@126.com

Update agent and UPGO

parent 4f2ad15b
This diff is collapsed.
...@@ -18,9 +18,12 @@ class RecordEpisodeStatistics(gym.Wrapper): ...@@ -18,9 +18,12 @@ class RecordEpisodeStatistics(gym.Wrapper):
return observations, infos return observations, infos
def step(self, action): def step(self, action):
observations, rewards, terminated, truncated, infos = super().step(action) return self.update_stats_and_infos(*super().step(action))
def update_stats_and_infos(self, *args):
observations, rewards, terminated, truncated, infos = args
dones = np.logical_or(terminated, truncated) dones = np.logical_or(terminated, truncated)
self.episode_returns += rewards self.episode_returns += infos.get("reward", rewards)
self.episode_lengths += 1 self.episode_lengths += 1
self.returned_episode_returns = np.where( self.returned_episode_returns = np.where(
dones, self.episode_returns, self.returned_episode_returns dones, self.episode_returns, self.returned_episode_returns
...@@ -32,6 +35,19 @@ class RecordEpisodeStatistics(gym.Wrapper): ...@@ -32,6 +35,19 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.episode_lengths *= 1 - dones self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths infos["l"] = self.returned_episode_lengths
# env_id = infos["env_id"]
# self.env_id = env_id
# self.episode_returns[env_id] += infos["reward"]
# self.returned_episode_returns[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
# )
# self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
# self.episode_lengths[env_id] += 1
# self.returned_episode_lengths[env_id] = np.where(
# infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
# )
# self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
return ( return (
observations, observations,
rewards, rewards,
...@@ -39,6 +55,19 @@ class RecordEpisodeStatistics(gym.Wrapper): ...@@ -39,6 +55,19 @@ class RecordEpisodeStatistics(gym.Wrapper):
infos, infos,
) )
def async_reset(self):
    # Start an asynchronous reset on the wrapped (envpool-style) env and
    # zero all per-env episode statistics, since every episode restarts.
    # NOTE(review): assumes self.num_envs is set by the enclosing wrapper —
    # confirm against the class __init__ (outside this view).
    self.env.async_reset()
    self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
    self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def recv(self):
    # Receive the next async step result and fold it into the episode
    # statistics, mirroring the synchronous step() path.
    return self.update_stats_and_infos(*self.env.recv())
def send(self, action):
    # Forward actions to the async env; bookkeeping happens in recv().
    return self.env.send(action)
class CompatEnv(gym.Wrapper): class CompatEnv(gym.Wrapper):
......
from functools import partial
import jax
import jax.numpy as jnp
from typing import NamedTuple
class VTraceOutput(NamedTuple):
    # Result of vtrace(): q_estimate are the Q(s, a) targets built from the
    # lambda-mixed bootstrap values; errors are the v-trace TD errors such
    # that errors + v_tm1 gives the value targets.
    q_estimate: jnp.ndarray
    errors: jnp.ndarray
def vtrace(
    v_tm1,
    v_t,
    r_t,
    discount_t,
    rho_tm1,
    lambda_=1.0,
    c_clip_min: float = 0.001,
    c_clip_max: float = 1.007,
    rho_clip_min: float = 0.001,
    rho_clip_max: float = 1.007,
    stop_target_gradients: bool = True,
):
    """Compute v-trace errors and Q estimates from importance weights.

    Args:
      v_tm1: values at time t-1.
      v_t: values at time t.
      r_t: reward at time t.
      discount_t: discount at time t.
      rho_tm1: importance sampling ratios at time t-1.
      lambda_: mixing parameter; a scalar or a vector for timesteps t.
      c_clip_min: lower clip bound for the trace coefficients c_t.
      c_clip_max: upper clip bound for the trace coefficients c_t.
      rho_clip_min: lower clip bound for the TD-error weights rho_t.
      rho_clip_max: upper clip bound for the TD-error weights rho_t.
      stop_target_gradients: whether or not to apply stop gradient to targets.

    Returns:
      VTraceOutput with `q_estimate` (lambda-mixed Q targets) and `errors`
      (v-trace errors; `errors + v_tm1` are the value targets).
    """
    # Clip importance sampling ratios.
    lambda_ = jnp.ones_like(discount_t) * lambda_
    c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
    clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)

    # Compute the temporal difference errors.
    td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    def _body(acc, xs):
        td_error, discount, c = xs
        acc = td_error + discount * c * acc
        return acc, acc

    _, errors = jax.lax.scan(
        _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)

    # Return errors, maybe disabling gradient flow through bootstrap targets.
    errors = jax.lax.select(
        stop_target_gradients,
        jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
        errors)

    # Q targets: bootstrap each step with the lambda-mix of the next value
    # target and the next raw value; the final step uses v_t directly.
    targets_tm1 = errors + v_tm1
    q_bootstrap = jnp.concatenate([
        lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
        v_t[-1:],
    ], axis=0)
    q_estimate = r_t + discount_t * q_bootstrap
    return VTraceOutput(q_estimate=q_estimate, errors=errors)
def upgo_return(r_t, v_t, discount_t, stop_target_gradients: bool = True):
    """Compute UPGO (upgoing policy update) returns.

    The return bootstraps with the running return while the one-step Q
    estimate is at least the value, and falls back to the value otherwise.
    """
    # TODO: following alphastar, estimate q_t with one-step target
    # It might be better to use network to estimate q_t
    q_t = r_t[1:] + discount_t[1:] * v_t[1:]  # q[:-1]

    def _step(running, inputs):
        reward, value, q_est, disc = inputs
        bootstrap = jnp.where(q_est >= value, running, value)
        running = reward + disc * bootstrap
        return running, running

    _, body_returns = jax.lax.scan(
        _step, q_t[-1], (r_t[:-1], v_t[:-1], q_t, discount_t[:-1]),
        reverse=True)

    # Following rlax.vtrace_td_error_and_advantage, part of gradient is reserved
    # Experiments show that where to stop gradient has no impact on the performance
    body_returns = jax.lax.select(
        stop_target_gradients, jax.lax.stop_gradient(body_returns), body_returns)
    return jnp.concatenate([body_returns, q_t[-1:]], axis=0)
def clipped_surrogate_pg_loss(prob_ratios_t, adv_t, mask, epsilon, use_stop_gradient=True):
    """PPO-style clipped surrogate policy-gradient loss.

    Masked entries contribute zero to the mean (the mean is taken over all
    elements, not only the valid ones).
    """
    advantages = jax.lax.select(
        use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
    bounded_ratios = jnp.clip(prob_ratios_t, 1. - epsilon, 1. + epsilon)
    per_step = jnp.fmin(prob_ratios_t * advantages, bounded_ratios * advantages)
    return -jnp.mean(per_step * mask)
def compute_gae_once(carry, inp, gamma, gae_lambda):
    # One reverse-scan step of GAE for a two-player alternating setting:
    # stream "1" tracks the learning player (`learn`), stream "2" its opponent.
    # Rewards are mirrored via `factor` — presumably a zero-sum game; TODO confirm.
    nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
    next_done, curvalues, reward, learn = inp
    learn1 = learn
    learn2 = ~learn
    # Player 1 sees +reward, player 2 sees -reward at episode boundaries.
    factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
    # A stream is "really done" if the env finished or its pending done was
    # never consumed; then its bootstrap value and running GAE are reset.
    real_done1 = next_done | ~done_used1
    nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    # done_used tracks whether the pending done has been consumed by the
    # stream that acts at this step.
    done_used1 = jnp.where(
        next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
    # Standard GAE recursion, maintained independently per stream.
    delta1 = reward1 + gamma * nextvalues1 - curvalues
    delta2 = reward2 + gamma * nextvalues2 - curvalues
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    # Emit the advantage of whichever stream acts at this timestep.
    advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
    # Only the acting stream updates its carried value / GAE state.
    nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
    nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
    lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
    carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    return carry, advantages
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae(
    next_value, next_done, next_learn,
    values, rewards, dones, learns,
    gamma, gae_lambda,
):
    """Run the two-player GAE backward scan over a trajectory.

    Returns (advantages, target_values) where target_values = advantages + values.
    """
    # Sign the bootstrap value for the learning player; the opponent's view
    # is its negation.
    value_self = jnp.where(next_learn, next_value, -next_value)
    value_oppo = -value_self
    used = jnp.ones_like(next_done)
    zeros = jnp.zeros_like(next_value)
    init_carry = (value_self, value_oppo, used, used, zeros, zeros, zeros, zeros)

    # Shift dones by one step so each scan input sees the *next* done flag.
    all_dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    step_fn = partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda)
    _, advantages = jax.lax.scan(
        step_fn, init_carry, (all_dones[1:], values, rewards, learns),
        reverse=True)
    return advantages, advantages + values
def compute_gae_once_upgo(carry, inp, gamma, gae_lambda):
    # One reverse-scan step computing both GAE advantages and UPGO returns
    # for a two-player alternating setting: stream "1" is the learning player
    # (`learn`), stream "2" its opponent. Rewards are mirrored via `factor` —
    # presumably a zero-sum game; TODO confirm.
    next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2 = carry
    next_done, curvalues, reward, learn = inp
    learn1 = learn
    learn2 = ~learn
    # Player 1 sees +reward, player 2 sees -reward at episode boundaries.
    factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
    # A stream is "really done" if the env finished or its pending done was
    # never consumed; then its bootstrap state is reset to zero.
    real_done1 = next_done | ~done_used1
    next_value1 = jnp.where(real_done1, 0, next_value1)
    last_return1 = jnp.where(real_done1, 0, last_return1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    next_value2 = jnp.where(real_done2, 0, next_value2)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    # done_used tracks whether the pending done has been consumed by the
    # stream that acts at this step.
    done_used1 = jnp.where(
        next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(
        next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
    # UPGO return: bootstrap with the running return only while the next
    # one-step Q estimate is at least the next value.
    last_return1_ = reward1 + gamma * jnp.where(
        next_q1 >= next_value1, last_return1, next_value1)
    last_return2_ = reward2 + gamma * jnp.where(
        next_q2 >= next_value2, last_return2, next_value2)
    # One-step Q estimates and the standard GAE recursion per stream.
    next_q1_ = reward1 + gamma * next_value1
    next_q2_ = reward2 + gamma * next_value2
    delta1 = next_q1_ - curvalues
    delta2 = next_q2_ - curvalues
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    # Emit the outputs of whichever stream acts at this timestep.
    returns = jnp.where(learn1, last_return1_, last_return2_)
    advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
    # Only the acting stream updates its carried state.
    next_value1 = jnp.where(learn1, curvalues, next_value1)
    next_value2 = jnp.where(learn2, curvalues, next_value2)
    lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
    next_q1 = jnp.where(learn1, next_q1_, next_q1)
    # BUG FIX: the fallback previously read `next_q1`, leaking player 1's Q
    # estimate into player 2's carry; keep player 2's own estimate instead.
    next_q2 = jnp.where(learn2, next_q2_, next_q2)
    last_return1 = jnp.where(learn1, last_return1_, last_return1)
    last_return2 = jnp.where(learn2, last_return2_, last_return2)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    return carry, (advantages, returns)
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
    next_value, next_done, next_learn,
    values, rewards, dones, learns,
    gamma, gae_lambda,
):
    """Run the two-player UPGO/GAE backward scan over a trajectory.

    Returns `(returns - values, advantages + values)` — i.e. UPGO-based
    advantages paired with GAE-based value targets. NOTE(review): the
    pairing is asymmetric by design here; confirm against callers.
    """
    # Sign the bootstrap value for the learning player; the opponent's view
    # is its negation. Q estimates and running returns start at the values.
    value_self = jnp.where(next_learn, next_value, -next_value)
    value_oppo = -value_self
    used = jnp.ones_like(next_done)
    zeros = jnp.zeros_like(next_value)
    init_carry = (value_self, value_oppo, value_self, value_oppo,
                  value_self, value_oppo, used, used,
                  zeros, zeros, zeros, zeros)

    # Shift dones by one step so each scan input sees the *next* done flag.
    all_dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    step_fn = partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda)
    _, (advantages, returns) = jax.lax.scan(
        step_fn, init_carry, (all_dones[1:], values, rewards, learns),
        reverse=True)
    return returns - values, advantages + values
...@@ -5,54 +5,13 @@ import jax ...@@ -5,54 +5,13 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax.linen as nn import flax.linen as nn
from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding from ygoai.rl.jax.transformer import EncoderLayer, DecoderLayer, PositionalEncoding
def decode_id(x): default_embed_init = nn.initializers.uniform(scale=0.001)
x = x[..., 0] * 256 + x[..., 1]
return x
def bytes_to_bin(x, points, intervals):
points = points.astype(x.dtype)
intervals = intervals.astype(x.dtype)
x = decode_id(x)
x = jnp.expand_dims(x, -1)
return jnp.clip((x - points + intervals) / intervals, 0, 1)
def make_bin_params(x_max=32000, n_bins=32, sig_bins=24):
x_max1 = 8000
x_max2 = x_max
points1 = jnp.linspace(0, x_max1, sig_bins + 1, dtype=jnp.float32)[1:]
points2 = jnp.linspace(x_max1, x_max2, n_bins - sig_bins + 1, dtype=jnp.float32)[1:]
points = jnp.concatenate([points1, points2], axis=0)
intervals = jnp.concatenate([points[0:1], points[1:] - points[:-1]], axis=0)
return points, intervals
default_embed_init = nn.initializers.uniform(scale=0.0001)
default_fc_init1 = nn.initializers.uniform(scale=0.001) default_fc_init1 = nn.initializers.uniform(scale=0.001)
default_fc_init2 = nn.initializers.uniform(scale=0.0001) default_fc_init2 = nn.initializers.uniform(scale=0.001)
class MLP(nn.Module):
features: Tuple[int, ...] = (128, 128)
last_lin: bool = True
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
@nn.compact
def __call__(self, x):
n = len(self.features)
for i, c in enumerate(self.features):
x = nn.Dense(
c, dtype=self.dtype, param_dtype=self.param_dtype,
kernel_init=self.kernel_init, use_bias=False)(x)
if i < n - 1 or not self.last_lin:
x = nn.relu(x)
return x
class ActionEncoder(nn.Module): class ActionEncoder(nn.Module):
...@@ -105,18 +64,19 @@ class Encoder(nn.Module): ...@@ -105,18 +64,19 @@ class Encoder(nn.Module):
layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False) layer_norm = partial(nn.LayerNorm, use_scale=False, use_bias=False)
embed = partial( embed = partial(
nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init) nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
fc_layer = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype) fc_embed = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
fc_layer = partial(nn.Dense, use_bias=False, dtype=jnp.float32, param_dtype=self.param_dtype)
id_embed = embed(n_embed, embed_dim) id_embed = embed(n_embed, embed_dim)
count_embed = embed(100, c // 16) count_embed = embed(100, c // 16)
hand_count_embed = embed(100, c // 16) hand_count_embed = embed(100, c // 16)
num_fc = MLP((c // 8,), last_lin=False, dtype=self.dtype, param_dtype=self.param_dtype) num_fc = MLP((c // 8,), last_lin=False, dtype=jnp.float32, param_dtype=self.param_dtype)
bin_points, bin_intervals = make_bin_params(n_bins=32) bin_points, bin_intervals = make_bin_params(n_bins=32)
num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals)) num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
action_encoder = ActionEncoder(channels=c, dtype=self.dtype, param_dtype=self.param_dtype) action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
x_cards = x['cards_'] x_cards = x['cards_']
x_global = x['global_'] x_global = x['global_']
x_actions = x['actions_'] x_actions = x['actions_']
...@@ -125,12 +85,12 @@ class Encoder(nn.Module): ...@@ -125,12 +85,12 @@ class Encoder(nn.Module):
valid = x_global[:, -1] == 0 valid = x_global[:, -1] == 0
x_cards_1 = x_cards[:, :, :12].astype(jnp.int32) x_cards_1 = x_cards[:, :, :12].astype(jnp.int32)
x_cards_2 = x_cards[:, :, 12:].astype(self.dtype or jnp.float32) x_cards_2 = x_cards[:, :, 12:].astype(jnp.float32)
x_id = decode_id(x_cards_1[:, :, :2]) x_id = decode_id(x_cards_1[:, :, :2])
x_id = id_embed(x_id) x_id = id_embed(x_id)
x_id = MLP( x_id = MLP(
(c, c // 4), dtype=self.dtype, param_dtype=self.param_dtype, (c, c // 4), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(x_id) kernel_init=default_fc_init2)(x_id)
x_id = layer_norm()(x_id) x_id = layer_norm()(x_id)
...@@ -152,10 +112,10 @@ class Encoder(nn.Module): ...@@ -152,10 +112,10 @@ class Encoder(nn.Module):
x_negated = embed(3, c // 16)(x_cards_1[:, :, 11]) x_negated = embed(3, c // 16)(x_cards_1[:, :, 11])
x_atk = num_transform(x_cards_2[:, :, 0:2]) x_atk = num_transform(x_cards_2[:, :, 0:2])
x_atk = fc_layer(c // 16, kernel_init=default_fc_init1)(x_atk) x_atk = fc_embed(c // 16, kernel_init=default_fc_init1)(x_atk)
x_def = num_transform(x_cards_2[:, :, 2:4]) x_def = num_transform(x_cards_2[:, :, 2:4])
x_def = fc_layer(c // 16, kernel_init=default_fc_init1)(x_def) x_def = fc_embed(c // 16, kernel_init=default_fc_init1)(x_def)
x_type = fc_layer(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:]) x_type = fc_embed(c // 16 * 2, kernel_init=default_fc_init2)(x_cards_2[:, :, 4:])
x_feat = jnp.concatenate([ x_feat = jnp.concatenate([
x_owner, x_position, x_overley, x_attribute, x_owner, x_position, x_overley, x_attribute,
...@@ -173,14 +133,14 @@ class Encoder(nn.Module): ...@@ -173,14 +133,14 @@ class Encoder(nn.Module):
'na_card_embed', 'na_card_embed',
lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02, lambda key, shape, dtype: jax.random.normal(key, shape, dtype) * 0.02,
(1, c), self.param_dtype) (1, c), self.param_dtype)
f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)) f_na_card = jnp.tile(na_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
f_cards = jnp.concatenate([f_na_card, f_cards], axis=1) f_cards = jnp.concatenate([f_na_card, f_cards], axis=1)
c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1) c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
f_cards = layer_norm()(f_cards) f_cards = layer_norm()(f_cards)
x_global_1 = x_global[:, :4].astype(self.dtype or jnp.float32) x_global_1 = x_global[:, :4].astype(jnp.float32)
x_g_lp = fc_layer(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2])) x_g_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 0:2]))
x_g_oppo_lp = fc_layer(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4])) x_g_oppo_lp = fc_embed(c // 4, kernel_init=default_fc_init2)(num_transform(x_global_1[:, 2:4]))
x_global_2 = x_global[:, 4:8].astype(jnp.int32) x_global_2 = x_global[:, 4:8].astype(jnp.int32)
x_g_turn = embed(20, c // 8)(x_global_2[:, 0]) x_g_turn = embed(20, c // 8)(x_global_2[:, 0])
...@@ -197,7 +157,7 @@ class Encoder(nn.Module): ...@@ -197,7 +157,7 @@ class Encoder(nn.Module):
x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn, x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1) x_g_cs, x_g_my_hand_c, x_g_op_hand_c], axis=-1)
x_global = layer_norm()(x_global) x_global = layer_norm()(x_global)
f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global) f_global = x_global + MLP((c * 2, c * 2), dtype=jnp.float32, param_dtype=self.param_dtype)(x_global)
f_global = fc_layer(c)(f_global) f_global = fc_layer(c)(f_global)
f_global = layer_norm()(f_global) f_global = layer_norm()(f_global)
...@@ -220,14 +180,14 @@ class Encoder(nn.Module): ...@@ -220,14 +180,14 @@ class Encoder(nn.Module):
f_actions, f_cards, f_actions, f_cards,
tgt_key_padding_mask=a_mask, tgt_key_padding_mask=a_mask,
memory_key_padding_mask=c_mask) memory_key_padding_mask=c_mask)
x_h_actions = x['h_actions_'].astype(jnp.int32) x_h_actions = x['h_actions_'].astype(jnp.int32)
h_mask = x_h_actions[:, :, 2] == 0 # msg == 0 h_mask = x_h_actions[:, :, 2] == 0 # msg == 0
h_mask = h_mask.at[:, 0].set(False) h_mask = h_mask.at[:, 0].set(False)
x_h_id = decode_id(x_h_actions[..., :2]) x_h_id = decode_id(x_h_actions[..., :2])
x_h_id = MLP( x_h_id = MLP(
(c, c), dtype=self.dtype, param_dtype=self.param_dtype, (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=default_fc_init2)(id_embed(x_h_id)) kernel_init=default_fc_init2)(id_embed(x_h_id))
x_h_a_feats = action_encoder(x_h_actions[:, :, 2:]) x_h_a_feats = action_encoder(x_h_actions[:, :, 2:])
...@@ -237,9 +197,9 @@ class Encoder(nn.Module): ...@@ -237,9 +197,9 @@ class Encoder(nn.Module):
for _ in range(self.num_action_layers): for _ in range(self.num_action_layers):
f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)( f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
f_h_actions, src_key_padding_mask=h_mask) f_h_actions, src_key_padding_mask=h_mask)
for _ in range(self.num_action_layers): for _ in range(self.num_action_layers):
f_actions = DecoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)( f_actions = DecoderLayer(num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(
f_actions, f_h_actions, f_actions, f_h_actions,
tgt_key_padding_mask=a_mask, tgt_key_padding_mask=a_mask,
memory_key_padding_mask=h_mask) memory_key_padding_mask=h_mask)
...@@ -261,11 +221,12 @@ class Actor(nn.Module): ...@@ -261,11 +221,12 @@ class Actor(nn.Module):
@nn.compact @nn.compact
def __call__(self, f_actions, mask): def __call__(self, f_actions, mask):
c = self.channels c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
num_heads = max(2, c // 128) num_heads = max(2, c // 128)
f_actions = EncoderLayer( f_actions = EncoderLayer(
num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask) num_heads, dtype=jnp.float32, param_dtype=self.param_dtype)(f_actions, src_key_padding_mask=mask)
logits = MLP((c // 4, 1), dtype=self.dtype, param_dtype=self.param_dtype)(f_actions) logits = mlp((c // 4, 1), use_bias=True)(f_actions)
logits = logits[..., 0].astype(jnp.float32) logits = logits[..., 0]
big_neg = jnp.finfo(logits.dtype).min big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits) logits = jnp.where(mask, big_neg, logits)
return logits return logits
...@@ -279,8 +240,8 @@ class Critic(nn.Module): ...@@ -279,8 +240,8 @@ class Critic(nn.Module):
@nn.compact @nn.compact
def __call__(self, f_state): def __call__(self, f_state):
c = self.channels c = self.channels
x = MLP((c // 2, 1), dtype=self.dtype, param_dtype=self.param_dtype)(f_state) mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(1.0))
x = x.astype(jnp.float32) x = MLP((c // 2, 1), use_bias=True)(f_state)
return x return x
......
This diff is collapsed.
from typing import Tuple, Union, Optional
import jax.numpy as jnp
import flax.linen as nn
def decode_id(x):
    """Combine the two trailing byte channels into one integer id (hi * 256 + lo)."""
    hi, lo = x[..., 0], x[..., 1]
    return hi * 256 + lo
def bytes_to_bin(x, points, intervals):
    """Soft-binning of a 2-byte encoded number.

    Each output channel ramps linearly from 0 to 1 as the decoded value
    crosses the corresponding bin, so the encoding is piecewise-linear
    rather than one-hot.
    """
    pts = points.astype(x.dtype)
    widths = intervals.astype(x.dtype)
    value = jnp.expand_dims(decode_id(x), -1)
    return jnp.clip((value - pts + widths) / widths, 0, 1)
def make_bin_params(x_max=12000, n_bins=32, sig_bins=24):
    """Build bin edges and widths for bytes_to_bin.

    The first `sig_bins` bins cover [0, 8000] at fine resolution; the
    remaining bins cover (8000, x_max] more coarsely.
    """
    split = 8000
    fine = jnp.linspace(0, split, sig_bins + 1, dtype=jnp.float32)[1:]
    coarse = jnp.linspace(split, x_max, n_bins - sig_bins + 1, dtype=jnp.float32)[1:]
    points = jnp.concatenate([fine, coarse], axis=0)
    intervals = jnp.concatenate([points[:1], jnp.diff(points)], axis=0)
    return points, intervals
class MLP(nn.Module):
    """Stack of Dense layers with leaky-ReLU (slope 0.1) between them.

    When `last_lin` is True the final layer is left linear and uses
    `last_kernel_init`; otherwise every layer is followed by the activation.
    """
    features: Tuple[int, ...] = (128, 128)
    last_lin: bool = True
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = False

    @nn.compact
    def __call__(self, x):
        last = len(self.features) - 1
        for i, width in enumerate(self.features):
            is_final_linear = self.last_lin and i == last
            init = self.last_kernel_init if is_final_linear else self.kernel_init
            x = nn.Dense(
                width, dtype=self.dtype, param_dtype=self.param_dtype,
                kernel_init=init, use_bias=self.use_bias)(x)
            if not is_final_linear:
                x = nn.leaky_relu(x, negative_slope=0.1)
        return x
...@@ -632,6 +632,7 @@ class EncoderLayer(nn.Module): ...@@ -632,6 +632,7 @@ class EncoderLayer(nn.Module):
@nn.compact @nn.compact
def __call__(self, inputs, src_key_padding_mask=None): def __call__(self, inputs, src_key_padding_mask=None):
inputs = jnp.asarray(inputs, self.dtype)
x = nn.LayerNorm(epsilon=self.layer_norm_epsilon, x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
dtype=self.dtype, name="ln_1")(inputs) dtype=self.dtype, name="ln_1")(inputs)
x = MultiheadAttention( x = MultiheadAttention(
......
This diff is collapsed.
...@@ -55,7 +55,7 @@ def masked_normalize(x, valid, eps=1e-8): ...@@ -55,7 +55,7 @@ def masked_normalize(x, valid, eps=1e-8):
return (x - mean) / std return (x - mean) / std
def to_tensor(x, device, dtype=torch.float32): def to_tensor(x, device, dtype=None):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x) return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
......
# Quick sanity-check script: print every environment id registered
# with the envpool2 package.
import envpool2
print(envpool2.list_all_envs())
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment