Commit eaba913c authored by novelailab

fix

parent 3c7bd057
@@ -78,13 +78,6 @@ class SelfAttention(nn.Module):
         key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
         value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
-        if kv:
-            k, v = kv
-            # cat key and value (get the whole sequence, other than the last added token all are cached),
-            # so query can attend to it.
-            torch.cat([k, key], dim=-2) # cat key
-            torch.cat([v, value], dim=-2) # cat value
-
         offset = 0
         if self.rotary_dim < self.head_dim:
             k_rot = key[:, :, :, :self.rotary_dim]
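The concatenation removed here ran before `apply_rotary_pos_emb` and before `key` is transposed: at this point `key` is still `(B, S, n_head, head_dim)` (only `value` has been transposed), so `dim=-2` indexes the head axis rather than the sequence axis, and the already-rotated cached keys would be rotated a second time further down. Below is a minimal sketch of the axis issue, assuming the shapes implied by the diff; note also that `torch.cat` is out-of-place, so the bare calls shown discard their result unless assigned back.

```python
import torch

B, S, n_head, head_dim = 1, 1, 16, 64            # one new token per step
cached_k = torch.randn(B, n_head, 7, head_dim)   # 7 previously cached key positions

key = torch.randn(B, S, n_head, head_dim)        # pre-transpose layout from k_proj
# Before the transpose, dim=-2 is n_head, not the sequence axis:
assert key.shape[-2] == n_head

key = key.transpose(1, 2)                        # (B, n_head, S, head_dim)
key = torch.cat([cached_k, key], dim=-2)         # now dim=-2 grows the sequence
assert key.shape == (B, n_head, 8, head_dim)
```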
@@ -103,13 +96,16 @@ class SelfAttention(nn.Module):
         key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
         query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
-        if cache:
-            # doing this to avoid transposing key again after loading it as transposed.
-            cache = (key, )
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
+        if kv:
+            k, v = kv
+            # cat key and value (get the whole sequence, other than the last added token all are cached),
+            # so query can attend to it.
+            torch.cat([k, key], dim=-2) # cat key
+            torch.cat([v, value], dim=-2) # cat value
         query_length, key_length = query.size(-2), key.size(-2)
         #causal mask with generation in mind
         causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
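With the concatenation moved after the rotary embedding, `apply_rotary_pos_emb` touches only the new token's query and key (at position `offset`), while the cached keys keep the rotations they received when first computed. The mask slice then selects the last `query_length` rows of the precomputed causal bias, which is what lets a single-token query attend over the whole cached sequence. A hedged sketch, assuming `self.bias` is a registered lower-triangular boolean buffer of shape `(1, 1, max_positions, max_positions)`, as in GPT-J-style attention:

```python
import torch

max_positions = 8
# Precomputed causal bias, as typically registered in GPT-J-style attention.
bias = torch.tril(
    torch.ones(max_positions, max_positions, dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

# Incremental generation: one query token attending over 5 total key positions
# (4 cached + 1 new), so the slice picks row 4 of the triangular mask.
query_length, key_length = 1, 5
causal_mask = bias[:, :, key_length - query_length:key_length, :key_length]

assert causal_mask.shape == (1, 1, 1, 5)
assert causal_mask.all()  # the newest token may attend to every cached position
```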
@@ -122,7 +118,7 @@ class SelfAttention(nn.Module):
         x = self.out_proj(x)
         if cache:
-            return x, (cache[0], value)
+            return x, (key, value)
         else:
             return x
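After the fix, the cache returned under `cache=True` is simply `(key, value)`, both already transposed to `(B, n_head, seq, head_dim)`, so the tuple can be fed straight back in as `kv` on the next decoding step (assuming the `torch.cat` results are assigned back, as noted above). A hypothetical driver sketch of that contract; `fake_attn` below is an illustrative stand-in, not the repository's module:

```python
import torch

def fake_attn(x, kv=None, cache=False):
    # Stand-in for SelfAttention.forward: fresh K/V for the current tokens,
    # concatenated onto the cache along the sequence axis.
    B, S, n_head, head_dim = x.shape[0], x.shape[1], 2, 4
    key = torch.randn(B, n_head, S, head_dim)
    value = torch.randn(B, n_head, S, head_dim)
    if kv:
        k, v = kv
        key = torch.cat([k, key], dim=-2)
        value = torch.cat([v, value], dim=-2)
    if cache:
        return x, (key, value)
    return x

x = torch.randn(1, 3, 8)                           # prompt of 3 tokens
out, kv = fake_attn(x, cache=True)                 # prime the cache
out, kv = fake_attn(x[:, -1:], kv=kv, cache=True)  # then decode one token at a time
print(kv[0].shape)                                 # torch.Size([1, 2, 4, 4]): 3 cached + 1 new
```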