Commit ac4ccfa1 authored by AUTOMATIC1111's avatar AUTOMATIC1111

get attention optimizations to work

parent b717eb7e
...@@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None): ...@@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
return context_k, context_v return context_k, context_v
def attention_CrossAttention_forward(self, x, context=None, mask=None): def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads h = self.heads
q = self.to_q(x) q = self.to_q(x)
......
...@@ -239,6 +239,7 @@ def mute_sdxl_imports(): ...@@ -239,6 +239,7 @@ def mute_sdxl_imports():
sys.modules['sgm.data'] = module sys.modules['sgm.data'] = module
def prepare_environment(): def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
......
...@@ -173,7 +173,7 @@ def get_available_vram(): ...@@ -173,7 +173,7 @@ def get_available_vram():
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
...@@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona ...@@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona
# taken from https://github.com/Doggettx/stable-diffusion and modified # taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
...@@ -355,7 +355,7 @@ def einsum_op(q, k, v): ...@@ -355,7 +355,7 @@ def einsum_op(q, k, v):
return einsum_op_tensor_mem(q, k, v, 32) return einsum_op_tensor_mem(q, k, v, 32)
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
h = self.heads h = self.heads
q = self.to_q(x) q = self.to_q(x)
...@@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add ...@@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads h = self.heads
...@@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): ...@@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
return None return None
def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
...@@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke ...@@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
batch_size, sequence_length, inner_dim = x.shape batch_size, sequence_length, inner_dim = x.shape
if mask is not None: if mask is not None:
...@@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit ...@@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit
return hidden_states return hidden_states
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask) return scaled_dot_product_attention_forward(self, x, context, mask)
......
...@@ -55,3 +55,6 @@ sgm.modules.diffusionmodules.model.print = lambda *args: None ...@@ -55,3 +55,6 @@ sgm.modules.diffusionmodules.model.print = lambda *args: None
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
sgm.modules.encoders.modules.print = lambda *args: None sgm.modules.encoders.modules.print = lambda *args: None
# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment