    Fix upcast attention dtype error. · d9cc0910
    Alexander Ljungberg authored
    Without this fix, enabling the "Upcast cross attention layer to float32" option together with `--opt-sdp-attention` breaks generation with the following error:
    
    ```
      File "/ext3/automatic1111/stable-diffusion-webui/modules/sd_hijack_optimizations.py", line 612, in sdp_attnblock_forward
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::Half instead.
    ```
    
    The fix is to upcast the value tensor as well, so that query, key, and value all share the same dtype before the attention call.
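    
    For illustration, a minimal standalone sketch of the failure and the fix (hypothetical tensor shapes; this is not the actual patch to `sdp_attnblock_forward`):
    
    ```python
    import torch
    
    # Simulate the upcast option: q and k were upcast to float32,
    # but v was left in half precision.
    q = torch.randn(1, 8, 64, dtype=torch.float16).float()
    k = torch.randn(1, 8, 64, dtype=torch.float16).float()
    v = torch.randn(1, 8, 64, dtype=torch.float16)
    
    # Before the fix, this call raised the RuntimeError shown above,
    # because SDPA requires q, k, and v to share one dtype:
    # out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    
    # The fix: upcast v too, then cast the output back to the original dtype.
    v = v.float()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(torch.float16)
    ```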