Commit 6752ba72 authored by sbl1996@126.com

Add card_mask and film to agent

parent 1d35fed3
@@ -3,7 +3,7 @@ import time
 import os
 import random
 from typing import Optional
-from dataclasses import dataclass
+from dataclasses import dataclass, field, asdict
 from tqdm import tqdm
 from functools import partial
@@ -18,7 +18,7 @@ import flax
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
-from ygoai.rl.jax.agent2 import RNNAgent
+from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs


 @dataclass
@@ -44,10 +44,6 @@ class Args:
     """the number of history actions to use"""
     num_embeddings: Optional[int] = None
     """the number of embeddings of the agent"""
-    use_history1: bool = True
-    """whether to use history actions as input for agent1"""
-    use_history2: bool = True
-    """whether to use history actions as input for agent2"""
     verbose: bool = False
     """whether to print debug information"""
@@ -59,16 +55,11 @@ class Args:
     num_envs: int = 64
     """the number of parallel game environments"""
-    num_layers: int = 2
-    """the number of layers for the agent"""
-    num_channels: int = 128
-    """the number of channels for the agent"""
-    rnn_channels: Optional[int] = 512
-    """the number of rnn channels for the agent"""
-    rnn_type1: Optional[str] = "lstm"
-    """the type of RNN to use for agent1, None for no RNN"""
-    rnn_type2: Optional[str] = "lstm"
-    """the type of RNN to use for agent2, None for no RNN"""
+    m1: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent1"""
+    m2: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent2"""
     checkpoint1: str = "checkpoints/agent.pt"
     """the checkpoint to load for the first agent, must be a `flax_model` file"""
     checkpoint2: str = "checkpoints/agent.pt"
@@ -83,23 +74,15 @@ class Args:
 def create_agent1(args):
     return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        rnn_channels=args.rnn_channels,
+        **asdict(args.m1),
         embedding_shape=args.num_embeddings,
-        use_history=args.use_history1,
-        rnn_type=args.rnn_type1,
     )


 def create_agent2(args):
     return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        rnn_channels=args.rnn_channels,
+        **asdict(args.m2),
         embedding_shape=args.num_embeddings,
-        use_history=args.use_history2,
-        rnn_type=args.rnn_type2,
     )
...
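Note (not part of the commit): the per-agent flags (num_layers, rnn_type1, use_history2, ...) are folded into a nested ModelArgs config, and asdict() flattens it back into RNNAgent keyword arguments. A minimal standalone sketch of that mechanism, with illustrative names and values only:

from dataclasses import dataclass, field, asdict
from typing import Optional

@dataclass
class ModelArgsSketch:
    num_layers: int = 2
    num_channels: int = 128
    rnn_channels: int = 512
    use_history: bool = True
    card_mask: bool = False
    rnn_type: Optional[str] = "lstm"
    film: bool = False
    noam: bool = False

@dataclass
class ArgsSketch:
    m1: ModelArgsSketch = field(default_factory=ModelArgsSketch)
    m2: ModelArgsSketch = field(default_factory=ModelArgsSketch)

args = ArgsSketch(m2=ModelArgsSketch(use_history=False, rnn_type=None))
kwargs = asdict(args.m2)
# kwargs == {'num_layers': 2, ..., 'use_history': False, 'rnn_type': None, 'film': False, 'noam': False}
# RNNAgent(**kwargs, embedding_shape=...) then receives exactly one value per ModelArgs field.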
@@ -6,7 +6,7 @@ import threading
 import time
 from datetime import datetime, timedelta, timezone
 from collections import deque
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, asdict
 from types import SimpleNamespace
 from typing import List, NamedTuple, Optional
 from functools import partial
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
-from ygoai.rl.jax.agent2 import RNNAgent
+from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
@@ -80,10 +80,6 @@ class Args:
     """the number of history actions to use"""
     greedy_reward: bool = False
     """whether to use greedy reward (faster kill higher reward)"""
-    use_history: bool = True
-    """whether to use history actions as input for agent"""
-    eval_use_history: bool = True
-    """whether to use history actions as input for eval agent"""
     total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
@@ -146,16 +142,10 @@ class Args:
     max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""
-    num_layers: int = 2
-    """the number of layers for the agent"""
-    num_channels: int = 128
-    """the number of channels for the agent"""
-    rnn_channels: int = 512
-    """the number of channels for the RNN in the agent"""
-    rnn_type: Optional[str] = "lstm"
-    """the type of RNN to use, None for no RNN"""
-    eval_rnn_type: Optional[str] = "lstm"
-    """the type of RNN to use for evaluation, None for no RNN"""
+    m1: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent"""
+    m2: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the eval agent"""
     actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
     """the device ids that actor workers will use"""
@@ -228,18 +218,22 @@ class Transition(NamedTuple):
 def create_agent(args, eval=False):
-    return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        embedding_shape=args.num_embeddings,
-        dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
-        param_dtype=jnp.float32,
-        rnn_channels=args.rnn_channels,
-        switch=args.switch,
-        freeze_id=args.freeze_id,
-        use_history=args.use_history if not eval else args.eval_use_history,
-        rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
-    )
+    if eval:
+        return RNNAgent(
+            embedding_shape=args.num_embeddings,
+            dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
+            param_dtype=jnp.float32,
+            **asdict(args.m2),
+        )
+    else:
+        return RNNAgent(
+            embedding_shape=args.num_embeddings,
+            dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
+            param_dtype=jnp.float32,
+            switch=args.switch,
+            freeze_id=args.freeze_id,
+            **asdict(args.m1),
+        )


 def init_rnn_state(num_envs, rnn_channels):
...
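Note (not part of the commit): with create_agent split on eval, the training model is built from args.m1 (plus switch/freeze_id) while the eval/opponent model is built from args.m2, so the two can now differ structurally. A hypothetical usage sketch, with illustrative field values that are not defaults of this commit:

# Illustrative configuration only.
args.m1 = ModelArgs(num_layers=2, num_channels=128, card_mask=True, film=True, noam=True)
args.m2 = ModelArgs(num_layers=2, num_channels=128, rnn_type="lstm")

train_agent = create_agent(args)             # built from **asdict(args.m1) + switch/freeze_id
eval_agent = create_agent(args, eval=True)   # built from **asdict(args.m2) only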
-from typing import Tuple, Union, Optional, Sequence
+from dataclasses import dataclass
+from typing import Tuple, Union, Optional, Sequence, Literal
 from functools import partial

 import numpy as np
@@ -6,7 +7,7 @@ import jax
 import jax.numpy as jnp
 import flax.linen as nn
-from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding
+from ygoai.rl.jax.transformer import EncoderLayer, PositionalEncoding, LlamaEncoderLayer
 from ygoai.rl.jax.modules import MLP, make_bin_params, bytes_to_bin, decode_id
@@ -15,6 +16,13 @@ default_fc_init1 = nn.initializers.uniform(scale=0.001)
 default_fc_init2 = nn.initializers.uniform(scale=0.001)

+def get_encoder_layer_cls(noam, n_heads, dtype, param_dtype):
+    if noam:
+        return LlamaEncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype, rope=False)
+    else:
+        return EncoderLayer(n_heads, dtype=dtype, param_dtype=param_dtype)
+
+
 class ActionEncoder(nn.Module):
     channels: int = 128
     dtype: Optional[jnp.dtype] = None
@@ -69,6 +77,8 @@ class CardEncoder(nn.Module):
         x_id = layer_norm()(x_id)

         x_loc = x1[:, :, 0]
+        c_mask = x_loc == 0
+        c_mask = c_mask.at[:, 0].set(False)
         f_loc = layer_norm()(embed(9, c)(x_loc))

         x_seq = x1[:, :, 1]
@@ -97,7 +107,7 @@ class CardEncoder(nn.Module):
         f_cards = jnp.concatenate([x_id, x_f], axis=-1)
         f_cards = f_cards + f_loc + f_seq
-        return f_cards
+        return f_cards, c_mask


 class GlobalEncoder(nn.Module):
@@ -153,6 +163,8 @@ class Encoder(nn.Module):
     param_dtype: jnp.dtype = jnp.float32
     freeze_id: bool = False
     use_history: bool = True
+    card_mask: bool = False
+    noam: bool = False

     @nn.compact
     def __call__(self, x):
@@ -188,7 +200,7 @@
             x_id = jax.lax.stop_gradient(x_id)

         # Cards
-        f_cards = CardEncoder(
+        f_cards, c_mask = CardEncoder(
             channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)(x_id, x_cards[:, :, 2:])
         g_card_embed = self.param(
             'g_card_embed',
@@ -196,10 +208,16 @@
             (1, c), self.param_dtype)
         f_g_card = jnp.tile(g_card_embed, (batch_size, 1, 1)).astype(f_cards.dtype)
         f_cards = jnp.concatenate([f_g_card, f_cards], axis=1)
+        if self.card_mask:
+            c_mask = jnp.concatenate([jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)
+        else:
+            c_mask = None

         num_heads = max(2, c // 128)
         for _ in range(self.num_layers):
-            f_cards = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(f_cards)
+            f_cards = get_encoder_layer_cls(
+                self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
+                f_cards, src_key_padding_mask=c_mask)
         f_cards = layer_norm(dtype=self.dtype)(f_cards)
         f_g_card = f_cards[:, 0]
@@ -294,6 +312,32 @@ class Actor(nn.Module):
         return logits


+class FiLMActor(nn.Module):
+    channels: int = 128
+    dtype: Optional[jnp.dtype] = None
+    param_dtype: jnp.dtype = jnp.float32
+    noam: bool = False
+
+    @nn.compact
+    def __call__(self, f_state, f_actions, mask):
+        f_state = f_state.astype(self.dtype)
+        f_actions = f_actions.astype(self.dtype)
+        c = self.channels
+        t = nn.Dense(c * 4, dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
+        a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
+
+        num_heads = max(2, c // 128)
+        f_actions = get_encoder_layer_cls(
+            self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
+            f_actions, mask, a_s, a_b, o_s, o_b)
+
+        logits = nn.Dense(1, dtype=jnp.float32, param_dtype=self.param_dtype,
+                          kernel_init=nn.initializers.orthogonal(0.01))(f_actions)[:, :, 0]
+        big_neg = jnp.finfo(logits.dtype).min
+        logits = jnp.where(mask, big_neg, logits)
+        return logits
+
+
 class Critic(nn.Module):
     channels: Sequence[int] = (128, 128, 128)
     dtype: Optional[jnp.dtype] = None
@@ -340,10 +384,29 @@ def rnn_forward_2p(rnn_layer, rstate, f_state, done, switch_or_main, switch=True
     return rstate, f_state


+@dataclass
+class ModelArgs:
+    num_layers: int = 2
+    """the number of layers for the agent"""
+    num_channels: int = 128
+    """the number of channels for the agent"""
+    rnn_channels: int = 512
+    """the number of channels for the RNN in the agent"""
+    use_history: bool = True
+    """whether to use history actions as input for agent"""
+    card_mask: bool = False
+    """whether to mask the padding card as ignored in the transformer"""
+    rnn_type: Optional[Literal['lstm', 'gru', 'none']] = "lstm"
+    """the type of RNN to use, None for no RNN"""
+    film: bool = False
+    """whether to use FiLM for the actor"""
+    noam: bool = False
+    """whether to use Noam architecture for the transformer layer"""
+
+
 class RNNAgent(nn.Module):
-    channels: int = 128
     num_layers: int = 2
+    num_channels: int = 128
     rnn_channels: int = 512
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: jnp.dtype = jnp.float32
@@ -351,11 +414,14 @@ class RNNAgent(nn.Module):
     switch: bool = True
     freeze_id: bool = False
     use_history: bool = True
+    card_mask: bool = False
     rnn_type: str = 'lstm'
+    film: bool = False
+    noam: bool = False

     @nn.compact
     def __call__(self, x, rstate, done=None, switch_or_main=None):
-        c = self.channels
+        c = self.num_channels
         encoder = Encoder(
             channels=c,
             num_layers=self.num_layers,
@@ -364,6 +430,8 @@ class RNNAgent(nn.Module):
             param_dtype=self.param_dtype,
             freeze_id=self.freeze_id,
             use_history=self.use_history,
+            card_mask=self.card_mask,
+            noam=self.noam,
         )
         f_actions, f_state, mask, valid = encoder(x)
@@ -401,8 +469,12 @@ class RNNAgent(nn.Module):
         rstate, f_state_r = rnn_step_by_main(
             rnn_layer, rstate, f_state, done, switch_or_main)

-        actor = Actor(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+        if self.film:
+            actor = FiLMActor(
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, noam=self.noam)
+        else:
+            actor = Actor(
+                channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
         critic = Critic(
             channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
...
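Note (not part of the commit): the card_mask added above marks empty card slots (x_loc == 0) as padding to be ignored by attention, always keeps slot 0 valid, and prepends a False entry for the global g_card_embed token before the transformer layers. A small standalone illustration of that mask construction:

import jax.numpy as jnp

x_loc = jnp.array([[3, 4, 0, 0],
                   [5, 0, 0, 0]])          # (batch, num_cards), 0 = empty slot
c_mask = x_loc == 0                        # True = padding slot to be ignored
c_mask = c_mask.at[:, 0].set(False)        # always keep the first card slot
batch_size = c_mask.shape[0]
c_mask = jnp.concatenate(
    [jnp.zeros((batch_size, 1), dtype=c_mask.dtype), c_mask], axis=1)  # global token is never masked
print(c_mask)
# [[False False False  True  True]
#  [False False  True  True  True]]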
@@ -596,7 +596,6 @@ class GLUMlpBlock(nn.Module):
                 param_dtype=self.param_dtype,
                 kernel_init=self.kernel_init,
                 bias_init=self.bias_init,
-                shard=self.shard,
             ) for _ in range(3)
         ]
@@ -631,7 +630,10 @@ class EncoderLayer(nn.Module):
     deterministic: bool = True

     @nn.compact
-    def __call__(self, inputs, src_key_padding_mask=None):
+    def __call__(
+        self, inputs, src_key_padding_mask=None,
+        attn_scale=None, attn_bias=None,
+        output_scale=None, output_bias=None):
         inputs = jnp.asarray(inputs, self.dtype)
         x = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
                          dtype=self.dtype, name="ln_1")(inputs)
@@ -648,6 +650,11 @@ class EncoderLayer(nn.Module):
         x = nn.Dropout(rate=self.resid_pdrop)(
             x, deterministic=self.deterministic)
+        if attn_scale is not None:
+            x = x * attn_scale
+        if attn_bias is not None:
+            x = x + attn_bias
         x = x + inputs

         y = nn.LayerNorm(epsilon=self.layer_norm_epsilon,
@@ -662,7 +669,13 @@
             name="mlp")(y)
         y = nn.Dropout(rate=self.resid_pdrop)(
             y, deterministic=self.deterministic)
+        if output_scale is not None:
+            y = y * output_scale
+        if output_bias is not None:
+            y = y + output_bias
         y = x + y
         return y
@@ -733,8 +746,9 @@ class DecoderLayer(nn.Module):
 class LlamaEncoderLayer(nn.Module):
     n_heads: int
-    intermediate_size: int
+    intermediate_size: Optional[int] = None
     n_positions: int = 512
+    rope: bool = True
     dtype: Any = None
     param_dtype: Any = jnp.float32
     attn_pdrop: float = 0.0
@@ -745,11 +759,17 @@ class LlamaEncoderLayer(nn.Module):
     deterministic: bool = True

     @nn.compact
-    def __call__(self, inputs, src_key_padding_mask=None):
+    def __call__(
+        self, inputs, src_key_padding_mask=None,
+        attn_scale=None, attn_bias=None,
+        output_scale=None, output_bias=None):
+        features = inputs.shape[-1]
+        intermediate_size = self.intermediate_size or 2 * features
         x = RMSNorm(epsilon=self.rms_norm_eps,
                     dtype=self.dtype, name="ln_1")(inputs)
         x = MultiheadAttention(
-            features=x.shape[-1],
+            features=features,
             num_heads=self.n_heads,
             max_len=self.n_positions,
             dtype=self.dtype,
@@ -757,19 +777,24 @@
             kernel_init=self.kernel_init,
             qkv_bias=False,
             out_bias=False,
-            rope=True,
+            rope=self.rope,
             dropout_rate=self.attn_pdrop,
             deterministic=self.deterministic,
             name="attn")(x, x, x, key_padding_mask=src_key_padding_mask)
         x = nn.Dropout(rate=self.resid_pdrop)(
             x, deterministic=self.deterministic)
+        if attn_scale is not None:
+            x = x * attn_scale
+        if attn_bias is not None:
+            x = x + attn_bias
         x = x + inputs

         y = RMSNorm(epsilon=self.rms_norm_eps,
                     dtype=self.dtype, name="ln_2")(x)
         y = GLUMlpBlock(
-            intermediate_size=self.intermediate_size,
+            intermediate_size=intermediate_size,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             kernel_init=self.kernel_init,
@@ -777,6 +802,12 @@
             name="mlp")(y)
         y = nn.Dropout(rate=self.resid_pdrop)(
             y, deterministic=self.deterministic)
+        if output_scale is not None:
+            y = y * output_scale
+        if output_bias is not None:
+            y = y + output_bias
         y = x + y
         return y
@@ -785,6 +816,7 @@ class LlamaDecoderLayer(nn.Module):
     n_heads: int
     intermediate_size: int
     n_positions: int = 512
+    rope: bool = True
     dtype: Any = None
     param_dtype: Any = jnp.float32
     attn_pdrop: float = 0.0
@@ -808,7 +840,7 @@ class LlamaDecoderLayer(nn.Module):
             kernel_init=self.kernel_init,
             qkv_bias=False,
             out_bias=False,
-            rope=True,
+            rope=self.rope,
             dropout_rate=self.attn_pdrop,
             deterministic=self.deterministic,
             name="self_attn")(x, x, x, key_padding_mask=tgt_key_padding_mask)
@@ -827,7 +859,7 @@ class LlamaDecoderLayer(nn.Module):
             kernel_init=self.kernel_init,
             qkv_bias=False,
             out_bias=False,
-            rope=True,
+            rope=self.rope,
             dropout_rate=self.attn_pdrop,
             deterministic=self.deterministic,
             name="cross_attn")(y, memory, memory, key_padding_mask=memory_key_padding_mask)
...
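Note (not part of the commit): the new attn_scale/attn_bias/output_scale/output_bias arguments carry FiLM-style conditioning: FiLMActor projects the state feature into four chunks that scale and shift the attention and MLP branches before their residual additions. A simplified, self-contained sketch of that data flow, where two Dense layers stand in for the real attention and MLP sub-blocks:

import jax
import jax.numpy as jnp
import flax.linen as nn

class FiLMBlockSketch(nn.Module):
    channels: int = 128

    @nn.compact
    def __call__(self, f_state, f_actions):
        c = self.channels
        t = nn.Dense(c * 4)(f_state)                        # (batch, 4c) conditioning vector
        a_s, a_b, o_s, o_b = jnp.split(t[:, None, :], 4, axis=-1)
        # Inside EncoderLayer / LlamaEncoderLayer this becomes:
        #   x = attn(ln(inputs)); x = x * attn_scale + attn_bias; x = x + inputs
        #   y = mlp(ln(x));       y = y * output_scale + output_bias; y = x + y
        x = nn.Dense(c)(f_actions)                          # stand-in for the attention branch
        x = x * a_s + a_b + f_actions
        y = nn.Dense(c)(x)                                  # stand-in for the MLP branch
        y = y * o_s + o_b + x
        return y

f_state = jnp.ones((2, 128))        # per-env state embedding
f_actions = jnp.ones((2, 5, 128))   # per-action embeddings
params = FiLMBlockSketch().init(jax.random.PRNGKey(0), f_state, f_actions)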