Commit b00b4294 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #14559 from Nuullll/ipex-sdpa-fix

[IPEX] Fix SDPA attn_mask dtype
parents 8b6848c6 16b4d2cf
...@@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( ...@@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
# cast to same dtype first # cast to same dtype first
key = key.to(query.dtype) key = key.to(query.dtype)
value = value.to(query.dtype) value = value.to(query.dtype)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(query.dtype)
N = query.shape[:-2] # Batch size N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length L = query.size(-2) # Target sequence length
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment