Commit 4f79dd7a authored by kurumuz's avatar kurumuz

revert

parent 81a39ebb
...@@ -188,16 +188,19 @@ class StableDiffusionModel(nn.Module): ...@@ -188,16 +188,19 @@ class StableDiffusionModel(nn.Module):
typex = torch.float32 typex = torch.float32
self.model = model.to(config.device).to(typex) self.model = model.to(config.device).to(typex)
if self.config.vae_path: if self.config.vae_path:
model.first_stage_model = model.first_stage_model.float()
ckpt=torch.load(self.config.vae_path, map_location="cpu") ckpt=torch.load(self.config.vae_path, map_location="cpu")
dec_ckpt = {} loss = []
for i in ckpt["state_dict"].keys(): for i in ckpt["state_dict"].keys():
if i[0:8] == "decoder.": if i[0:4] == "loss":
dec_ckpt[i[8:]] = ckpt["state_dict"][i] loss.append(i)
x, y = model.first_stage_model.decoder.load_state_dict(dec_ckpt) for i in loss:
del ckpt["state_dict"][i]
model.first_stage_model = model.first_stage_model.float()
model.first_stage_model.load_state_dict(ckpt["state_dict"])
model.first_stage_model = model.first_stage_model.float() model.first_stage_model = model.first_stage_model.float()
del ckpt del ckpt
del dec_ckpt del loss
config.logger.info(f"Using VAE from {self.config.vae_path}") config.logger.info(f"Using VAE from {self.config.vae_path}")
if self.config.penultimate == "1": if self.config.penultimate == "1":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment