Commit 9a3f35b0 authored by AUTOMATIC1111's avatar AUTOMATIC1111

repair medvram and lowvram

parent abb948da
...@@ -100,6 +100,8 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -100,6 +100,8 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder: if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if hasattr(sd_model, 'cond_stage_model'):
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram: if use_medvram:
......
...@@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit ...@@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
def encode_embedding_init_text(self, init_text, nvpt): def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text) ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int) ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded return embedded
...@@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi ...@@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi
def encode_embedding_init_text(self, init_text, nvpt): def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text) ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int) ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded return embedded
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment