Commit 4f79dd7a authored by kurumuz's avatar kurumuz

revert

parent 81a39ebb
......@@ -188,16 +188,19 @@ class StableDiffusionModel(nn.Module):
typex = torch.float32
self.model = model.to(config.device).to(typex)
if self.config.vae_path:
model.first_stage_model = model.first_stage_model.float()
ckpt=torch.load(self.config.vae_path, map_location="cpu")
dec_ckpt = {}
loss = []
for i in ckpt["state_dict"].keys():
if i[0:8] == "decoder.":
dec_ckpt[i[8:]] = ckpt["state_dict"][i]
x, y = model.first_stage_model.decoder.load_state_dict(dec_ckpt)
if i[0:4] == "loss":
loss.append(i)
for i in loss:
del ckpt["state_dict"][i]
model.first_stage_model = model.first_stage_model.float()
model.first_stage_model.load_state_dict(ckpt["state_dict"])
model.first_stage_model = model.first_stage_model.float()
del ckpt
del dec_ckpt
del loss
config.logger.info(f"Using VAE from {self.config.vae_path}")
if self.config.penultimate == "1":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment