Commit f9c5da15 authored by AUTOMATIC's avatar AUTOMATIC

add fallback for xformers_attnblock_forward

parent a5550f02
......@@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x):
return h3
def xformers_attnblock_forward(self, x):
try:
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_).contiguous()
......@@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_).contiguous()
out = xformers.ops.memory_efficient_attention(q1, k1, v)
out = self.proj_out(out)
return x+out
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment