Commit 87dd6852 authored by brkirch's avatar brkirch

Make sub-quadratic the default for MPS

parent abfa4ad8
...@@ -95,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem): ...@@ -95,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization): class SdOptimizationSubQuad(SdOptimization):
name = "sub-quadratic" name = "sub-quadratic"
cmd_opt = "opt_sub_quad_attention" cmd_opt = "opt_sub_quad_attention"
priority = 10
@property
def priority(self):
return 1000 if shared.device.type == 'mps' else 10
def apply(self): def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
...@@ -121,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization): ...@@ -121,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
@property @property
def priority(self): def priority(self):
return 1000 if not torch.cuda.is_available() else 10 return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
def apply(self): def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment