Commit 24e93cbf authored by novelailab

if cache

parent fa2cfe52
@@ -75,7 +75,7 @@ class SelfAttention(nn.Module):
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
 
-    def forward(self, x, kv=None):
+    def forward(self, x, kv=None, cache=False):
         B, S, H = x.shape # batch, sequence, hidden_dim
         # split heads into: [batch, head, sequence, head_dim]
         query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
@@ -98,7 +98,9 @@ class SelfAttention(nn.Module):
         x = x.transpose(1, 2).contiguous().view(B, S, H)
         x = self.out_proj(x)
 
-        return x
+        if cache:
+            return x, (key, value)
+        else:
+            return x
 
 class FeedForward(nn.Module):
...
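The change above only affects what forward returns; how a caller reuses the cached tensors through the existing kv argument is not shown in this diff. Below is a minimal, self-contained sketch of that pattern, assuming the usual approach of concatenating cached keys/values along the sequence axis. toy_attention and all shapes are illustrative stand-ins, not the repo's SelfAttention; heads, causal masking, and rotary embeddings are omitted for brevity.

import torch
import torch.nn as nn

def toy_attention(x, q_proj, k_proj, v_proj, kv=None, cache=False):
    # project only the tokens that were passed in
    query, key, value = q_proj(x), k_proj(x), v_proj(x)
    if kv is not None:
        # assumed reuse pattern: prepend previously cached keys/values along the
        # sequence dimension instead of recomputing the whole prefix
        past_key, past_value = kv
        key = torch.cat([past_key, key], dim=1)
        value = torch.cat([past_value, value], dim=1)
    scores = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ value
    if cache:
        # mirrors the commit: hand the key/value pair back to the caller
        return out, (key, value)
    return out

hidden_dim = 16
q_proj, k_proj, v_proj = (nn.Linear(hidden_dim, hidden_dim) for _ in range(3))

prompt = torch.randn(1, 8, hidden_dim)                      # [batch, seq, hidden]
out, past = toy_attention(prompt, q_proj, k_proj, v_proj, cache=True)

step = torch.randn(1, 1, hidden_dim)                        # one newly generated token
out = toy_attention(step, q_proj, k_proj, v_proj, kv=past)  # reuses the cached prefix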
@@ -30,34 +30,6 @@ def apply_rotary_pos_emb(x, sincos, offset=0):
     sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
     return (x * cos) + (rotate_every_two(x) * sin)
 
-def _split_heads(tensor, num_heads, attn_head_size, rotary):
-    """
-    Splits hidden_size dim into attn_head_size and num_heads
-    """
-    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-    tensor = tensor.view(*new_shape)
-    if rotary:
-        return tensor
-    if len(tensor.shape) == 5:
-        return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
-    elif len(tensor.shape) == 4:
-        return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
-    else:
-        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-
-def _merge_heads(tensor, num_heads, attn_head_size):
-    """
-    Merges attn_head_size dim and num_attn_heads dim into hidden_size
-    """
-    if len(tensor.shape) == 5:
-        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
-    elif len(tensor.shape) == 4:
-        tensor = tensor.permute(0, 2, 1, 3).contiguous()
-    else:
-        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-    return tensor.view(new_shape)
-
 def _attn(query, key, value, causal_mask, masked_bias,
           attention_mask=None, scale_attn=None):
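The two helpers are deleted in favour of plain view/transpose calls in the reworked forward further down. For the rank-4 tensors that forward actually handles, the two formulations are interchangeable; a small self-contained check of that equivalence, with arbitrary sizes chosen purely for illustration:

import torch

B, S, n_head, head_dim = 2, 5, 4, 8
x = torch.randn(B, S, n_head * head_dim)

# deleted helper, rank-4 branch: view to [B, S, head, head_dim], permute to [B, head, S, head_dim]
split_old = x.view(B, S, n_head, head_dim).permute(0, 2, 1, 3)
# replacement used in the new forward: view then transpose(1, 2)
split_new = x.view(B, S, n_head, head_dim).transpose(1, 2)
assert torch.equal(split_old, split_new)

# deleted _merge_heads vs. the transpose/contiguous/view round-trip from the new forward
merged_old = split_old.permute(0, 2, 1, 3).contiguous().view(B, S, n_head * head_dim)
merged_new = split_new.transpose(1, 2).contiguous().view(B, S, n_head * head_dim)
assert torch.equal(merged_old, merged_new) and torch.equal(merged_new, x)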
@@ -99,13 +71,12 @@ class SelfAttention(nn.Module):
         self.register_buffer("cos", cos)
 
     def forward(self, x):
-        query = self.q_proj(x)
-        key = self.k_proj(x)
-        value = self.v_proj(x)
-
-        query = _split_heads(query, self.n_head, self.head_dim, True)
-        key = _split_heads(key, self.n_head, self.head_dim, True)
-        value = _split_heads(value, self.n_head, self.head_dim, False)
+        B, S, H = x.shape # batch, sequence, hidden_dim
+        # split heads into: [batch, head, sequence, head_dim]
+        # other than v because some rotary bs?
+        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim)
 
         offset = 0
         if self.rotary_dim < self.head_dim:
@@ -125,17 +96,16 @@ class SelfAttention(nn.Module):
             key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
             query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
 
-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
-
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        #causal mask with generation in mind
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
 
         x = _attn(
             query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
         )
 
-        x = _merge_heads(x, self.n_head, self.head_dim)
+        x = x.transpose(1, 2).contiguous().view(B, S, H)
         x = self.out_proj(x)
         return x
...
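The new comment flags that the mask slice is written with generation in mind: when only the freshest query positions are fed in while key_length covers the whole prefix, key_length - query_length : key_length picks out the matching bottom rows of the triangular mask. A small illustration, assuming self.bias is the usual lower-triangular boolean buffer registered elsewhere in the class (its construction is not part of this diff):

import torch

max_positions = 8
# assumed construction of the buffer (standard GPT-style causal mask)
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

# full prompt pass: query_length == key_length, the slice is the whole triangular mask
query_length = key_length = 5
full = bias[:, :, key_length - query_length:key_length, :key_length]
print(full.shape)  # torch.Size([1, 1, 5, 5])

# single-token generation step: one fresh query attending over 6 cached positions
query_length, key_length = 1, 6
step = bias[:, :, key_length - query_length:key_length, :key_length]
print(step.shape)  # torch.Size([1, 1, 1, 6])
print(step.all())  # tensor(True) -- the newest position may attend to everything before it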