Commit f94a7fc3 authored by sbl1996@126.com

add option for action_feats

parent 8cebfebf
@@ -207,6 +207,7 @@ class Encoder(nn.Module):
     use_history: bool = True
     card_mask: bool = False
     noam: bool = False
+    action_feats: bool = True
     version: int = 0

     @nn.compact
@@ -392,14 +393,18 @@ class Encoder(nn.Module):
         a_mask = x_actions[:, :, 3] == 0
         a_mask = a_mask.at[:, 0].set(False)
-        f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
-        a_mask_ = (1 - a_mask.astype(f_actions.dtype))
-        f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
-        f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
+        g_feats = [f_g_card, f_global]
         if self.use_history:
-            f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
-        else:
-            f_state = jnp.concatenate([f_g_card, f_global, f_g_actions], axis=-1)
+            g_feats.append(f_g_h_actions)
+
+        if self.action_feats:
+            f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
+            a_mask_ = (1 - a_mask.astype(f_actions.dtype))
+            f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
+            f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
+            g_feats.append(f_g_actions)
+
+        f_state = jnp.concatenate(g_feats, axis=-1)

         oc = self.out_channels or c
         f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         f_state = layer_norm(dtype=self.dtype)(f_state)
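
Note on the hunk above: the action-feature branch is a masked mean pool over the action axis, where a_mask marks padded action slots and slot 0 is forced valid, so the denominator can never be zero. A minimal standalone sketch of the same computation, with masked_mean_pool as a hypothetical helper name and placeholder shapes:

import jax.numpy as jnp

def masked_mean_pool(f_actions, a_mask):
    # f_actions: (batch, num_actions, channels) per-action features
    # a_mask: (batch, num_actions) bool, True where the slot is padding
    valid = 1 - a_mask.astype(f_actions.dtype)            # 1.0 for real actions
    pooled = (f_actions * valid[:, :, None]).sum(axis=1)  # zero padding, then sum
    return pooled / valid.sum(axis=1, keepdims=True)      # mean over valid slots

f_actions = jnp.ones((2, 4, 8))
a_mask = jnp.array([[False, False, True, True],
                    [False, True, True, True]])
print(masked_mean_pool(f_actions, a_mask).shape)  # (2, 8)

When action_feats is False, this pooled vector is simply left out of g_feats, so the global state is built from the card and global features (plus history, if enabled) alone.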
@@ -516,6 +521,8 @@ class ModelArgs:
     """whether to use Noam architecture for the transformer layer"""
     rwkv_head_size: int = 32
     """the head size for the RWKV"""
+    action_feats: bool = True
+    """whether to use action features for the global state"""
     version: int = 0
     """the version of the environment and the agent"""

@@ -535,6 +542,7 @@ class RNNAgent(nn.Module):
     film: bool = False
     noam: bool = False
     rwkv_head_size: int = 32
+    action_feats: bool = True
     version: int = 0

     @nn.compact
@@ -552,6 +560,7 @@ class RNNAgent(nn.Module):
             use_history=self.use_history,
             card_mask=self.card_mask,
             noam=self.noam,
+            action_feats=self.action_feats,
             version=self.version,
         )

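
For illustration, the new flag flows from the config dataclass into the agent. A sketch of wiring it through; only the fields shown in this diff are assumed to exist, and everything else in the constructors is left at its defaults:

args = ModelArgs(action_feats=False)   # drop per-action features from the global state

agent = RNNAgent(
    noam=args.noam,
    rwkv_head_size=args.rwkv_head_size,
    action_feats=args.action_feats,    # new option added in this commit
    version=args.version,
)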
@@ -60,7 +60,6 @@ class GLUMlp(nn.Module):
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
     kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
-    last_kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
     use_bias: bool = False

     @nn.compact
@@ -73,7 +72,6 @@ class GLUMlp(nn.Module):
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             kernel_init=self.kernel_init,
-            bias_init=self.bias_init,
         ) for _ in range(3)
     ]

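
Side note on this second file: GLUMlp appears never to have declared a bias_init attribute (and use_bias defaults to False), so the removed bias_init=self.bias_init argument would raise an AttributeError once the module was actually built; dropping it lets nn.Dense fall back to its default bias initializer. A sketch of the construction after this change, assuming an intermediate_size field and the usual flax.linen as nn import (both illustrative, not taken from the diff):

from typing import Optional

import flax.linen as nn
import jax.numpy as jnp

class GLUMlp(nn.Module):
    # Sketch only: intermediate_size is an assumed field; the others
    # match the attributes visible in the diff above.
    intermediate_size: int = 128
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = False

    @nn.compact
    def __call__(self, x):
        # Three identical projections; with bias_init no longer passed,
        # nn.Dense uses its default bias initializer when use_bias is True.
        gate, up, out = [
            nn.Dense(
                self.intermediate_size,
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                kernel_init=self.kernel_init,
            ) for _ in range(3)
        ]
        return out(nn.silu(gate(x)) * up(x))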