Commit 5b242a97 authored by novelailab's avatar novelailab

Autodetect checkpoint type

parent 4fe25b9c
...@@ -279,7 +279,7 @@ class StableDiffusionModel(nn.Module): ...@@ -279,7 +279,7 @@ class StableDiffusionModel(nn.Module):
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
if self.config.basedformer == "1": if self.config.basedformer == "1" or 'state_dict' in pl_sd:
sd = pl_sd sd = pl_sd
else: else:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment