Commit fcaf7bf7 authored by sbl1996@126.com

Add nnx

parent 632f551d
(collapsed file diff not shown)
+from typing import List
import os
+import shutil
from pathlib import Path
import zipfile
@@ -16,10 +19,10 @@ class ModelCheckpoint(object):
    """
    def __init__(self, dirname, save_fn, n_saved=1):
-        self._dirname = Path(dirname).expanduser()
+        self._dirname = Path(dirname).expanduser().absolute()
        self._n_saved = n_saved
        self._save_fn = save_fn
-        self._saved = []
+        self._saved: List[Path] = []

    def _check_dir(self):
        self._dirname.mkdir(parents=True, exist_ok=True)
@@ -38,6 +41,9 @@ class ModelCheckpoint(object):
        if len(self._saved) > self._n_saved:
            path = self._saved.pop(0)
-            os.remove(path)
+            if path.is_dir():
+                shutil.rmtree(path)
+            else:
+                os.remove(path)

    def get_latest(self):
...
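nnx checkpoints are typically written as directories (for example Orbax-style checkpoint trees) rather than single files, which is why eviction now branches on path.is_dir(). Below is a standalone sketch of the updated eviction step; the helper name evict_old_checkpoints is illustrative only.

import os
import shutil
from pathlib import Path
from typing import List

def evict_old_checkpoints(saved: List[Path], n_saved: int) -> None:
    # Mirror of the change above: keep only the newest n_saved entries,
    # removing directory checkpoints with rmtree and plain files with os.remove.
    while len(saved) > n_saved:
        path = saved.pop(0)
        if path.is_dir():
            shutil.rmtree(path)
        else:
            os.remove(path)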
@@ -452,13 +452,14 @@ class Actor(nn.Module):
    channels: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
+    final_init: nn.initializers.Initializer = nn.initializers.orthogonal(0.01)

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
        f_state = f_state.astype(self.dtype)
        f_actions = f_actions.astype(self.dtype)
        c = self.channels
-        mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=nn.initializers.orthogonal(0.01))
+        mlp = partial(MLP, dtype=jnp.float32, param_dtype=self.param_dtype, last_kernel_init=self.final_init)
        f_state = mlp((c,), use_bias=True)(f_state)
        logits = jnp.einsum('bc,bnc->bn', f_state, f_actions)
        big_neg = jnp.finfo(logits.dtype).min
@@ -471,6 +472,7 @@ class FiLMActor(nn.Module):
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    noam: bool = False
+    final_init: nn.initializers.Initializer = nn.initializers.orthogonal(0.01)

    @nn.compact
    def __call__(self, f_state, f_actions, mask):
@@ -486,7 +488,7 @@ class FiLMActor(nn.Module):
            f_actions, mask, a_s, a_b, o_s, o_b)
        logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
-                          kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
+                          kernel_init=self.final_init)(f_actions)[:, :, 0]
        big_neg = jnp.finfo(logits.dtype).min
        logits = jnp.where(mask, big_neg, logits)
        return logits
@@ -647,6 +649,7 @@ class RNNAgent(nn.Module):
    critic_depth: int = 3
    version: int = 0
+    q_head: bool = False
    switch: bool = True
    freeze_id: bool = False
    int_head: bool = False
@@ -699,11 +702,6 @@ class RNNAgent(nn.Module):
        num_steps = f_state.shape[0] // batch_size
        multi_step = num_steps > 1
-        if done is not None:
-            assert switch_or_main is not None
-        else:
-            assert not multi_step
        if multi_step:
            f_state_r, done, switch_or_main = jax.tree.map(
                lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]), (f_state, done, switch_or_main))
@@ -722,13 +720,16 @@ class RNNAgent(nn.Module):
        # f_state_r = ReZero(channel_wise=True)(f_state_r)
        f_state_r = jnp.concatenate([f_state, f_state_r], axis=-1)
+        actor_init = nn.initializers.orthogonal(1) if self.q_head else nn.initializers.orthogonal(0.01)
        if self.film:
            actor = FiLMActor(
-                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam, final_init=actor_init)
        else:
            actor = Actor(
-                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, final_init=actor_init)
        logits = actor(f_state_r, f_actions, mask)
+        if self.q_head:
+            return rstate, logits, valid
        CriticCls = CrossCritic if self.batch_norm else Critic
        cs = [self.critic_width] * self.critic_depth
...
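The new final_init field lets the Q-head use orthogonal(1) while the policy head keeps orthogonal(0.01): a near-zero final layer gives near-uniform initial action logits, whereas unit-scale outputs are presumably preferred when the head regresses action values. A small standalone illustration of the output-scale difference (layer sizes here are arbitrary, not taken from the model):

import jax
import jax.numpy as jnp
import flax.linen as nn

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (4, 128))

policy_head = nn.Dense(8, kernel_init=nn.initializers.orthogonal(0.01))
q_head = nn.Dense(8, kernel_init=nn.initializers.orthogonal(1.0))

p_small = policy_head.init(rng, x)
p_unit = q_head.init(rng, x)

# Logits initialized near zero -> softmax is close to uniform at the start.
print(jnp.std(policy_head.apply(p_small, x)))  # ~0.01
# Unit-scale outputs, closer to typical value-regression targets.
print(jnp.std(q_head.apply(p_unit, x)))        # ~1.0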
(two collapsed file diffs not shown)
import jax
import jax.numpy as jnp
from flax import nnx

default_kernel_init = nnx.initializers.lecun_normal()
default_bias_init = nnx.initializers.zeros_init()


class OptimizedLSTMCell(nnx.Module):
    """LSTM cell that fuses the four gate projections into two Linear layers
    (one for the input, one for the hidden state) and a single split."""

    def __init__(
        self, in_features, features: int, *,
        gate_fn=nnx.sigmoid, activation_fn=nnx.tanh,
        kernel_init=default_kernel_init, bias_init=default_bias_init,
        recurrent_kernel_init=nnx.initializers.orthogonal(),
        dtype=None, param_dtype=jnp.float32, rngs,
    ):
        self.features = features
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        # Input projection produces all four gates at once; the bias lives in fc_h.
        self.fc_i = nnx.Linear(
            in_features, 4 * features,
            use_bias=False, kernel_init=kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )
        self.fc_h = nnx.Linear(
            features, 4 * features,
            use_bias=True, kernel_init=recurrent_kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )

    def __call__(self, carry, inputs):
        c, h = carry
        dense_i = self.fc_i(inputs)
        dense_h = self.fc_h(h)
        # Split the fused projection into input, forget, cell and output gates.
        i, f, g, o = jnp.split(dense_i + dense_h, indices_or_sections=4, axis=-1)
        i, f, g, o = self.gate_fn(i), self.gate_fn(f), self.activation_fn(g), self.gate_fn(o)
        new_c = f * c + i * g
        new_h = o * self.activation_fn(new_c)
        return (new_c, new_h), new_h


class GRUCell(nnx.Module):
    """PyTorch-style GRU cell (reset gate applied to the hidden projection),
    with the three gate projections fused into two Linear layers."""

    def __init__(
        self, in_features: int, features: int, *,
        gate_fn=nnx.sigmoid, activation_fn=nnx.tanh,
        kernel_init=default_kernel_init, bias_init=default_bias_init,
        recurrent_kernel_init=nnx.initializers.orthogonal(),
        dtype=None, param_dtype=jnp.float32, rngs,
    ):
        self.features = features
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        self.fc_i = nnx.Linear(
            in_features, 3 * features,
            use_bias=True, kernel_init=kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )
        self.fc_h = nnx.Linear(
            features, 3 * features,
            use_bias=True, kernel_init=recurrent_kernel_init,
            bias_init=bias_init, dtype=dtype,
            param_dtype=param_dtype, rngs=rngs,
        )

    def __call__(self, carry, inputs):
        h = carry
        dense_i = self.fc_i(inputs)
        dense_h = self.fc_h(h)
        # Reset (r), update (z) and candidate (n) projections.
        ir, iz, in_ = jnp.split(dense_i, indices_or_sections=3, axis=-1)
        hr, hz, hn = jnp.split(dense_h, indices_or_sections=3, axis=-1)
        r = self.gate_fn(ir + hr)
        z = self.gate_fn(iz + hz)
        n = self.activation_fn(in_ + r * hn)
        new_h = (1.0 - z) * n + z * h
        return new_h, new_h
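A minimal usage sketch for the cells above; the sizes are arbitrary, and real training code would typically drive the cell with jax.lax.scan or nnx.scan rather than a Python loop:

import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(params=0)
cell = GRUCell(in_features=16, features=32, rngs=rngs)

seq_len, batch = 10, 4
xs = jnp.zeros((seq_len, batch, 16))   # (time, batch, features) inputs
h = jnp.zeros((batch, 32))             # initial carry

for t in range(seq_len):
    h, y = cell(h, xs[t])              # returns (new_carry, output); for a GRU they are the same tensor

print(y.shape)  # (4, 32)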
(collapsed file diff not shown)
@@ -1528,7 +1528,7 @@ public:
        "max_cards"_.Bind(80), "n_history_actions"_.Bind(16),
        "record"_.Bind(false), "async_reset"_.Bind(false),
        "greedy_reward"_.Bind(true), "timeout"_.Bind(600),
-        "oppo_info"_.Bind(false));
+        "oppo_info"_.Bind(false), "max_steps"_.Bind(1000));
  }

  template <typename Config>
  static decltype(auto) StateSpec(const Config &conf) {
@@ -1629,6 +1629,7 @@ protected:
  std::uniform_int_distribution<uint64_t> dist_int_;

  bool done_{true};
+  long step_count_{0};
  bool duel_started_{false};
  uint32_t eng_flag_{0};
@@ -1947,6 +1948,7 @@ public:
    discard_hand_ = false;
    done_ = false;
+    step_count_ = 0;

    // update_time_stat(_start, reset_time_count_, reset_time_2_);
    // _start = clock();
@@ -2227,6 +2229,14 @@ public:
      next();
    }

+    step_count_++;
+    if (!done_ && (step_count_ >= spec_.config["max_steps"_])) {
+      PlayerId winner = lp_[0] > lp_[1] ? 0 : 1;
+      _duel_end(winner, 0x01);
+      done_ = true;
+      legal_actions_.clear();
+    }

    float reward = 0;
    int reason = 0;
    if (done_) {
@@ -2334,6 +2344,9 @@ public:
    if (n_options == 0) {
      state["info:num_options"_] = 1;
      state["obs:global_"_][22] = uint8_t(1);
+      // if (step_count_ >= spec_.config["max_steps"_]) {
+      //   fmt::println("Max steps reached return");
+      // }
      return;
    }
...
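The new max_steps option (default 1000) truncates a duel that runs too long: once the step counter reaches the limit, the duel is ended and the player with the higher LP is declared the winner. A small Python sketch of that rule, standalone and not tied to the actual env API (the function name is illustrative):

from typing import Optional, Tuple

def resolve_truncation(step_count: int, max_steps: int,
                       lp: Tuple[int, int], done: bool) -> Tuple[bool, Optional[int]]:
    # Mirror of the C++ logic above: on hitting the step limit the duel ends
    # and whoever has more LP wins (ties go to player 1, as in lp_[0] > lp_[1] ? 0 : 1).
    if not done and step_count >= max_steps:
        winner = 0 if lp[0] > lp[1] else 1
        return True, winner
    return done, None

# Example: player 0 leads on LP when the limit is reached.
print(resolve_truncation(1000, 1000, (4000, 2500), False))  # (True, 0)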