Commit 4ad0c0c0 authored by Aarni Koskela's avatar Aarni Koskela

Verify architecture for loaded Spandrel models

parent c7561335
...@@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
return modelloader.load_spandrel_model(filename, device=device) return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings(): def on_ui_settings():
......
...@@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
filename, filename,
device=self._get_device(), device=self._get_device(),
dtype=devices.dtype, dtype=devices.dtype,
expected_architecture="SwinIR",
) )
if getattr(opts, 'SWIN_torch_compile', False): if getattr(opts, 'SWIN_torch_compile', False):
try: try:
......
...@@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): ...@@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
return modelloader.load_spandrel_model( return modelloader.load_spandrel_model(
model_path, model_path,
device=devices.device_codeformer, device=devices.device_codeformer,
expected_architecture='CodeFormer',
).model ).model
raise ValueError("No codeformer model found") raise ValueError("No codeformer model found")
......
...@@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler): ...@@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
return modelloader.load_spandrel_model( return modelloader.load_spandrel_model(
filename, filename,
device=('cpu' if devices.device_esrgan.type == 'mps' else None), device=('cpu' if devices.device_esrgan.type == 'mps' else None),
expected_architecture='ESRGAN',
) )
......
...@@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): ...@@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
net = modelloader.load_spandrel_model( net = modelloader.load_spandrel_model(
model_path, model_path,
device=self.get_device(), device=self.get_device(),
expected_architecture='GFPGAN',
).model ).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net return net
......
...@@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler): ...@@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
return modelloader.load_spandrel_model( return modelloader.load_spandrel_model(
path, path,
device=devices.device_esrgan, # TODO: should probably be device_hat device=devices.device_esrgan, # TODO: should probably be device_hat
expected_architecture='HAT',
) )
...@@ -6,6 +6,8 @@ import shutil ...@@ -6,6 +6,8 @@ import shutil
import importlib import importlib
from urllib.parse import urlparse from urllib.parse import urlparse
import torch
from modules import shared from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
...@@ -183,9 +185,18 @@ def load_upscalers(): ...@@ -183,9 +185,18 @@ def load_upscalers():
) )
def load_spandrel_model(path, *, device, half: bool = False, dtype=None): def load_spandrel_model(
path: str,
*,
device: str | torch.device | None,
half: bool = False,
dtype: str | None = None,
expected_architecture: str | None = None,
):
import spandrel import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path) model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
raise TypeError(f"Model {path} is not a {expected_architecture} model")
if half: if half:
model = model.model.half() model = model.model.half()
if dtype: if dtype:
......
import os import os
from modules.upscaler_utils import upscale_with_model
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader, errors from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler): class UpscalerRealESRGAN(Upscaler):
...@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path, info.local_data_path,
device=self.device, device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="RealESRGAN",
) )
return upscale_with_model( return upscale_with_model(
mod, mod,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment