Commit ba66cf8d authored by wangshuai09's avatar wangshuai09

update

parent b7aa4253
......@@ -95,6 +95,7 @@ class HypernetworkModule(torch.nn.Module):
zeros_(b)
else:
raise KeyError(f"Key {weight_init} is not defined as initialization!")
devices.torch_npu_set_device()
self.to(devices.device)
def fix_old_state_dict(self, state_dict):
......
......@@ -230,7 +230,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
for fixes in self.hijack.fixes:
for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding
devices.torch_npu_set_device()
z = self.process_tokens(tokens, multipliers)
zs.append(z)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment