Commit f9c5da15 authored by AUTOMATIC

add fallback for xformers_attnblock_forward

parent a5550f02
@@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x):
     return h3
 
 def xformers_attnblock_forward(self, x):
+    try:
         h_ = x
         h_ = self.norm(h_)
         q1 = self.q(h_).contiguous()
@@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x):
         v = self.v(h_).contiguous()
         out = xformers.ops.memory_efficient_attention(q1, k1, v)
         out = self.proj_out(out)
-        return x+out
+        return x + out
+    except NotImplementedError:
+        return cross_attention_attnblock_forward(self, x)
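
For readability, here is a sketch of the whole function as it would read after this commit. The k1 line and the exact indentation sit in the collapsed part of the diff and are reconstructed here as assumptions; the module is also assumed to import xformers.ops and to define cross_attention_attnblock_forward elsewhere.

import xformers.ops  # assumed to be imported at module level

def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q1 = self.q(h_).contiguous()
        k1 = self.k(h_).contiguous()  # collapsed in the diff above; assumed to mirror q1/v
        v = self.v(h_).contiguous()
        # raises NotImplementedError when no xformers kernel supports the
        # given device/dtype/shape combination
        out = xformers.ops.memory_efficient_attention(q1, k1, v)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        # fallback: use the plain attention-block implementation instead
        return cross_attention_attnblock_forward(self, x)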