Commit bf67a5dc authored by Aarni Koskela's avatar Aarni Koskela

Upscaler.load_model: don't return None, just use exceptions

parent e3a973a6
...@@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler): ...@@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler):
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
try: return LDSR(model, yaml)
return LDSR(model, yaml)
except Exception:
errors.report("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path): def do_upscale(self, img, path):
ldsr = self.load_model(path) try:
if ldsr is None: ldsr = self.load_model(path)
print("NO LDSR!") except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)
......
import os.path
import sys import sys
import PIL.Image import PIL.Image
...@@ -8,7 +7,7 @@ from tqdm import tqdm ...@@ -8,7 +7,7 @@ from tqdm import tqdm
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url from modules.modelloader import load_file_from_url
from modules.shared import opts from modules.shared import opts
...@@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache() torch.cuda.empty_cache()
model = self.load_model(selected_file) try:
if model is None: model = self.load_model(selected_file)
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img return img
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
...@@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for _, v in model.named_parameters(): for _, v in model.named_parameters():
......
import os import sys
import numpy as np import numpy as np
import torch import torch
...@@ -7,8 +7,8 @@ from tqdm import tqdm ...@@ -7,8 +7,8 @@ from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
...@@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler): ...@@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img, model_file): def do_upscale(self, img, model_file):
model = self.load_model(model_file) try:
if model is None: model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img return img
model = model.to(device_swinir, dtype=devices.dtype) model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model) img = upscale(img, model)
...@@ -56,25 +58,23 @@ class UpscalerSwinIR(Upscaler): ...@@ -56,25 +58,23 @@ class UpscalerSwinIR(Upscaler):
) )
else: else:
filename = path filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"): if filename.endswith(".v2.pth"):
model = net2( model = Swin2SR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
window_size=8, window_size=8,
img_range=1.0, img_range=1.0,
depths=[6, 6, 6, 6, 6, 6], depths=[6, 6, 6, 6, 6, 6],
embed_dim=180, embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2, mlp_ratio=2,
upsampler="nearest+conv", upsampler="nearest+conv",
resi_connection="1conv", resi_connection="1conv",
) )
params = None params = None
else: else:
model = net( model = SwinIR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
......
import os import sys
import numpy as np import numpy as np
import torch import torch
...@@ -6,9 +6,8 @@ from PIL import Image ...@@ -6,9 +6,8 @@ from PIL import Image
import modules.esrgan_model_arch as arch import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict): def mod2normal(state_dict):
...@@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler): ...@@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data) self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model): def do_upscale(self, img, selected_model):
model = self.load_model(selected_model) try:
if model is None: model = self.load_model(selected_model)
except Exception as e:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img return img
model.to(devices.device_esrgan) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
...@@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler): ...@@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler):
) )
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None:
print(f"Unable to load {self.model_path} from {filename}")
return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
......
...@@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts ...@@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts
from modules import modelloader, errors from modules import modelloader, errors
class UpscalerRealESRGAN(Upscaler): class UpscalerRealESRGAN(Upscaler):
def __init__(self, path): def __init__(self, path):
self.name = "RealESRGAN" self.name = "RealESRGAN"
...@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable: if not self.enable:
return img return img
info = self.load_model(path) try:
if not os.path.exists(info.local_data_path): info = self.load_model(path)
print(f"Unable to load RealESRGAN model: {info.name}") except Exception:
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img return img
upsampler = RealESRGANer( upsampler = RealESRGANer(
...@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image return image
def load_model(self, path): def load_model(self, path):
try: for scaler in self.scalers:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
if info is None: scaler.local_data_path = modelloader.load_file_from_url(
print(f"Unable to find model info: {path}") scaler.data_path,
return None model_dir=self.model_download_path,
)
if info.local_data_path.startswith("http"): if not os.path.exists(scaler.local_data_path):
info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
return info raise ValueError(f"Unable to find model info: {path}")
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _): def load_models(self, _):
return get_realesrgan_models(self) return get_realesrgan_models(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment