Commit c8732dfa authored by Louis Del Valle's avatar Louis Del Valle Committed by GitHub

Update sub_quadratic_attention.py

1. Determine the number of query chunks.
2. Calculate the final shape of the res tensor.
3. Initialize the tensor with the calculated shape and dtype, (same dtype as the input tensors, usually)

Can initialize the tensor as a zero-filled tensor with the correct shape and dtype, then compute the attention scores for each query chunk and fill the corresponding slice of tensor.
parent 8aa87c56
...@@ -202,13 +202,22 @@ def efficient_dot_product_attention( ...@@ -202,13 +202,22 @@ def efficient_dot_product_attention(
value=value, value=value,
) )
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # slices of res tensor are mutable, modifications made
# and pass slices to be mutated, instead of torch.cat()ing the returned slices # to the slices will affect the original tensor.
res = torch.cat([ # if output of compute_query_chunk_attn function has same number of
compute_query_chunk_attn( # dimensions as input query tensor, we initialize tensor like this:
num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
query_shape = get_query_chunk(0).shape
res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
res_dtype = get_query_chunk(0).dtype
res = torch.zeros(res_shape, dtype=res_dtype)
for i in range(num_query_chunks):
attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size), query=get_query_chunk(i * query_chunk_size),
key=key, key=key,
value=value, value=value,
) for i in range(math.ceil(q_tokens / query_chunk_size)) )
], dim=1) res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
return res return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment