Commit fa2cfe52 authored by novelailab's avatar novelailab

fix attn more

parent 39568281
...@@ -78,9 +78,9 @@ class SelfAttention(nn.Module): ...@@ -78,9 +78,9 @@ class SelfAttention(nn.Module):
def forward(self, x, kv=None): def forward(self, x, kv=None):
B, S, H = x.shape # batch, sequence, hidden_dim B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim] # split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, self.n_head, S, self.head_dim) query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
key = self.k_proj(x).view(B, self.n_head, S, self.head_dim) key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, self.n_head, S, self.head_dim) value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv: if kv:
k, v = kv k, v = kv
...@@ -89,17 +89,14 @@ class SelfAttention(nn.Module): ...@@ -89,17 +89,14 @@ class SelfAttention(nn.Module):
torch.cat([k, key], dim=-2) # cat key torch.cat([k, key], dim=-2) # cat key
torch.cat([v, value], dim=-2) # cat value torch.cat([v, value], dim=-2) # cat value
key = key.permute(0, 2, 1, 3) query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
query = query.permute(0, 2, 1, 3) causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
x = _attn( x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
) )
x = x.contiguous().view(B, S, H) x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x) x = self.out_proj(x)
return x return x
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment