Commit fcaf7bf7 authored by sbl1996@126.com

Add nnx

parent 632f551d
from typing import List
import os
import shutil
from pathlib import Path
import zipfile
@@ -16,10 +19,10 @@ class ModelCheckpoint(object):
"""
def __init__(self, dirname, save_fn, n_saved=1):
self._dirname = Path(dirname).expanduser()
self._dirname = Path(dirname).expanduser().absolute()
self._n_saved = n_saved
self._save_fn = save_fn
self._saved = []
self._saved: List[Path] = []
def _check_dir(self):
self._dirname.mkdir(parents=True, exist_ok=True)
@@ -38,7 +41,10 @@ class ModelCheckpoint(object):
if len(self._saved) > self._n_saved:
path = self._saved.pop(0)
os.remove(path)
if path.is_dir():
shutil.rmtree(path)
else:
os.remove(path)
def get_latest(self):
path = self._saved[-1]
......
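The rotation change above lets ModelCheckpoint prune old checkpoints that are saved as directories (e.g. a checkpoint written as a directory tree) as well as single files. A standalone sketch of that pruning rule, with hypothetical names (prune_oldest is not part of the class, it just mirrors the branch added above):

import os
import shutil
from pathlib import Path
from typing import List

def prune_oldest(saved: List[Path], n_saved: int) -> None:
    # Keep at most n_saved entries, dropping the oldest first.
    while len(saved) > n_saved:
        path = saved.pop(0)
        if path.is_dir():
            shutil.rmtree(path)   # checkpoint stored as a directory
        elif path.exists():
            os.remove(path)       # checkpoint stored as a single file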
@@ -452,13 +452,14 @@ class Actor(nn.Module):
channels: int = 128
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
final_init: nn.initializers.Initializer = nn.initializers.orthogonal(0.01)
@nn.compact
def __call__(self, f_state, f_actions, mask):
f_state = f_state.astype(self.dtype)
f_actions = f_actions.astype(self.dtype)
c = self.channels
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=self.final_init)
f_state = mlp((c,), use_bias=True)(f_state)
logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
big_neg = jnp.finfo(logits.dtype).min
@@ -471,6 +472,7 @@ class FiLMActor(nn.Module):
dtype: Optional[jnp.dtype] = None
param_dtype: jnp.dtype = jnp.float32
noam: bool = False
final_init: nn.initializers.Initializer = nn.initializers.orthogonal(0.01)
@nn.compact
def __call__(self, f_state, f_actions, mask):
@@ -486,7 +488,7 @@ class FiLMActor(nn.Module):
f_actions, mask, a_s, a_b, o_s, o_b)
logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
kernel_init=self.final_init)(f_actions)[:, :, 0]
big_neg = jnp.finfo(logits.dtype).min
logits = jnp.where(mask, big_neg, logits)
return logits
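Both actor heads now take their final-layer initializer as a field instead of hard-coding orthogonal(0.01). A minimal construction sketch, assuming the Actor definition above is in scope; the gain choice mirrors the actor_init logic added in RNNAgent below:

import flax.linen as nn

# Small gain keeps the initial policy logits near zero (close-to-uniform policy).
policy_actor = Actor(channels=128, final_init=nn.initializers.orthogonal(0.01))
# A Q-value head starts from a standard orthogonal init instead.
q_actor = Actor(channels=128, final_init=nn.initializers.orthogonal(1.0))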
@@ -647,6 +649,7 @@ class RNNAgent(nn.Module):
critic_depth: int = 3
version: int = 0
q_head: bool = False
switch: bool = True
freeze_id: bool = False
int_head: bool = False
@@ -699,11 +702,6 @@ class RNNAgent(nn.Module):
num_steps = f_state.shape[0] // batch_size
multi_step = num_steps > 1
if done is not None:
assert switch_or_main is not None
else:
assert not multi_step
if multi_step:
f_state_r, done, switch_or_main = jax.tree.map(
lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
@@ -722,13 +720,16 @@ class RNNAgent(nn.Module):
# f_state_r = ReZero(channel_wise=True)(f_state_r)
f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
actor_init = nn.initializers.orthogonal(1) if self.q_head else nn.initializers.orthogonal(0.01)
if self.film:
actor = FiLMActor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam, final_init=actor_init)
else:
actor = Actor(
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, final_init=actor_init)
logits = actor(f_state_r, f_actions, mask)
if self.q_head:
return rstate, logits, valid
CriticCls = CrossCritic if self.batch_norm else Critic
cs = [self.critic_width] * self.critic_depth
......
@@ -154,7 +154,7 @@ class BatchRenorm(nn.Module):
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""
"""
use_running_average: Optional[bool] = None
axis: int = -1
......
import jax
import jax.numpy as jnp
from flax import nnx

default_kernel_init = nnx.initializers.lecun_normal()
default_bias_init = nnx.initializers.zeros_init()


class OptimizedLSTMCell(nnx.Module):
    """LSTM cell that computes all four gates with two fused linear projections."""

    def __init__(
        self, in_features, features: int, *,
        gate_fn=nnx.sigmoid, activation_fn=nnx.tanh,
        kernel_init=default_kernel_init, bias_init=default_bias_init,
        recurrent_kernel_init=nnx.initializers.orthogonal(),
        dtype=None, param_dtype=jnp.float32, rngs,
    ):
        self.features = features
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        # Input projection: one matmul producing all four gate pre-activations.
        self.fc_i = nnx.Linear(
            in_features, 4 * features,
            use_bias=False, kernel_init=kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )
        # Recurrent projection; this one carries the bias term.
        self.fc_h = nnx.Linear(
            features, 4 * features,
            use_bias=True, kernel_init=recurrent_kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )

    def __call__(self, carry, inputs):
        c, h = carry
        dense_i = self.fc_i(inputs)
        dense_h = self.fc_h(h)
        # Split the fused projection into input, forget, cell and output gates.
        i, f, g, o = jnp.split(dense_i + dense_h, indices_or_sections=4, axis=-1)
        i, f, g, o = self.gate_fn(i), self.gate_fn(f), self.activation_fn(g), self.gate_fn(o)
        new_c = f * c + i * g
        new_h = o * self.activation_fn(new_c)
        return (new_c, new_h), new_h


class GRUCell(nnx.Module):
    """GRU cell with fused input and recurrent projections."""

    def __init__(
        self, in_features: int, features: int, *,
        gate_fn=nnx.sigmoid, activation_fn=nnx.tanh,
        kernel_init=default_kernel_init, bias_init=default_bias_init,
        recurrent_kernel_init=nnx.initializers.orthogonal(),
        dtype=None, param_dtype=jnp.float32, rngs,
    ):
        self.features = features
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        # Input projection for the reset, update and candidate branches.
        self.fc_i = nnx.Linear(
            in_features, 3 * features,
            use_bias=True, kernel_init=kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )
        # Recurrent projection for the same three branches.
        self.fc_h = nnx.Linear(
            features, 3 * features,
            use_bias=True, kernel_init=recurrent_kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )

    def __call__(self, carry, inputs):
        h = carry
        dense_i = self.fc_i(inputs)
        dense_h = self.fc_h(h)
        ir, iz, in_ = jnp.split(dense_i, indices_or_sections=3, axis=-1)
        hr, hz, hn = jnp.split(dense_h, indices_or_sections=3, axis=-1)
        r = self.gate_fn(ir + hr)             # reset gate
        z = self.gate_fn(iz + hz)             # update gate
        n = self.activation_fn(in_ + r * hn)  # candidate state
        new_h = (1.0 - z) * n + z * h
        return new_h, new_h
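A usage sketch for the cells above, assuming they are importable; shapes and sizes are illustrative. Each cell maps (carry, inputs) to (new_carry, output), so it can be driven step by step or wrapped in a scan:

import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
gru = GRUCell(in_features=8, features=16, rngs=rngs)
lstm = OptimizedLSTMCell(8, features=16, rngs=rngs)

batch, seq_len = 4, 10
xs = jnp.zeros((seq_len, batch, 8))

h = jnp.zeros((batch, 16))                                 # GRU carry: a single hidden state
carry = (jnp.zeros((batch, 16)), jnp.zeros((batch, 16)))   # LSTM carry: (cell state, hidden state)

for t in range(seq_len):  # plain loop for clarity; jax.lax.scan / nnx.scan in real code
    h, y_gru = gru(h, xs[t])
    carry, y_lstm = lstm(carry, xs[t])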
@@ -1528,7 +1528,7 @@ public:
"max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
"record"_.Bind(false), "async_reset"_.Bind(false),
"greedy_reward"_.Bind(true), "timeout"_.Bind(600),
"oppo_info"_.Bind(false));
"oppo_info"_.Bind(false), "max_steps"_.Bind(1000));
}
template <typename Config>
static decltype(auto) StateSpec(const Config &conf) {
@@ -1629,6 +1629,7 @@ protected:
std::uniform_int_distribution<uint64_t> dist_int_;
bool done_{true};
long step_count_{0};
bool duel_started_{false};
uint32_t eng_flag_{0};
@@ -1947,6 +1948,7 @@ public:
discard_hand_ = false;
done_ = false;
step_count_ = 0;
// update_time_stat(_start, reset_time_count_, reset_time_2_);
// _start = clock();
@@ -2227,6 +2229,14 @@ public:
next();
}
step_count_++;
if (!done_ && (step_count_ >= spec_.config["max_steps"_])) {
PlayerId winner = lp_[0] > lp_[1] ? 0 : 1;
_duel_end(winner, 0x01);
done_ = true;
legal_actions_.clear();
}
float reward = 0;
int reason = 0;
if (done_) {
@@ -2334,6 +2344,9 @@ public:
if (n_options == 0) {
state["info:num_options"_] = 1;
state["obs:global_"_][22] = uint8_t(1);
// if (step_count_ >= spec_.config["max_steps"_]) {
// fmt::println("Max steps reached return");
// }
return;
}
......
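The max_steps changes above add a hard step limit to the environment: once step_count_ reaches the configured max_steps (bound to 1000 by default), the duel is ended and the player with more LP is declared the winner. A small Python sketch of that truncation rule, with hypothetical names (the actual check is the C++ block in step above):

def step_limit_winner(step_count: int, max_steps: int, lp0: int, lp1: int):
    # Return the winning player id once the step budget is exhausted, else None.
    if step_count >= max_steps:
        return 0 if lp0 > lp1 else 1
    return None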