Commit 746783f7 authored by Nuullll's avatar Nuullll

[IPEX] Fix embedding

Cast `torch.bmm` args into same `dtype`.

Fixes the following error when using Text Inversion embedding (#14224):

```
RuntimeError: could not create a primitive descriptor for a matmul
primitive
```
parent f92d6149
...@@ -48,3 +48,6 @@ if has_xpu: ...@@ -48,3 +48,6 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward', CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.bmm',
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment