Commit b0f59342 authored by Aarni Koskela

Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR)

parent e472383a
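The core idea of the change: instead of bundling a copy of each network architecture (SwinIR, ESRGAN, SCUNet, CodeFormer, ...) in-tree and hand-matching checkpoints to constructors, model files are now loaded through Spandrel, which detects the architecture from the checkpoint's state dict. A minimal sketch of that loading pattern, mirroring the `load_spandrel_model` helper added in modules/modelloader.py below (the checkpoint path here is hypothetical):

```python
import spandrel
import torch

# Spandrel inspects the state dict and instantiates the matching network
# (e.g. an RRDBNet for an ESRGAN checkpoint), replacing the per-architecture
# loader code this commit deletes.
loader = spandrel.ModelLoader(device=torch.device("cpu"))
model = loader.load_from_file("models/ESRGAN/4x_example.pth")  # hypothetical path

model.eval()  # the descriptor forwards eval()/to() to the wrapped module,
              # as the call sites in this commit rely on
print(type(model.model).__name__)  # the underlying torch.nn.Module
```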
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -7,9 +7,7 @@ from tqdm import tqdm
 import modules.upscaler
 from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-from modules.modelloader import load_file_from_url
 from modules.shared import opts
@@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         device = devices.get_device_for('scunet')
         if path.startswith("http"):
             # TODO: this doesn't use `path` at all?
-            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for _, v in model.named_parameters():
-            v.requires_grad = False
-        model = model.to(device)
-        return model
+        return modelloader.load_spandrel_model(filename, device=device)

 def on_ui_settings():
...
This diff is collapsed.
extensions-builtin/SwinIR/scripts/swinir_model.py

+import logging
 import sys
-import platform

 import numpy as np
 import torch
@@ -8,13 +8,11 @@ from tqdm import tqdm
 from modules import modelloader, devices, script_callbacks, shared
 from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
 from modules.upscaler import Upscaler, UpscalerData

 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)

 class UpscalerSwinIR(Upscaler):
@@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler):
                 scalers.append(model_data)
         self.scalers = scalers

-    def do_upscale(self, img, model_file):
-        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
-            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
         current_config = (model_file, opts.SWIN_tile)

-        if use_compile and self._cached_model_config == current_config:
+        device = self._get_device()
+
+        if self._cached_model_config == current_config:
             model = self._cached_model
         else:
+            self._cached_model = None
             try:
                 model = self.load_model(model_file)
             except Exception as e:
                 print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                 return img
-            model = model.to(device_swinir, dtype=devices.dtype)
-            if use_compile:
-                model = torch.compile(model)
             self._cached_model = model
             self._cached_model_config = current_config
-        img = upscale(img, model)
+
+        img = upscale(
+            img,
+            model,
+            tile=opts.SWIN_tile,
+            tile_overlap=opts.SWIN_tile_overlap,
+            device=device,
+        )
         devices.torch_gc()
         return img
@@ -69,69 +70,54 @@ class UpscalerSwinIR(Upscaler):
             )
         else:
             filename = path
-        if filename.endswith(".v2.pth"):
-            model = Swin2SR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6],
-                embed_dim=180,
-                num_heads=[6, 6, 6, 6, 6, 6],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="1conv",
-            )
-            params = None
-        else:
-            model = SwinIR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
-                embed_dim=240,
-                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="3conv",
-            )
-            params = "params_ema"
-
-        pretrained_model = torch.load(filename)
-        if params is not None:
-            model.load_state_dict(pretrained_model[params], strict=True)
-        else:
-            model.load_state_dict(pretrained_model, strict=True)
+        model = modelloader.load_spandrel_model(
+            filename,
+            device=self._get_device(),
+            dtype=devices.dtype,
+        )
+        if getattr(opts, 'SWIN_torch_compile', False):
+            try:
+                model = torch.compile(model)
+            except Exception:
+                logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
         return model

+    def _get_device(self):
+        return devices.get_device_for('swinir')

 def upscale(
     img,
     model,
-    tile=None,
-    tile_overlap=None,
+    *,
+    tile: int,
+    tile_overlap: int,
     window_size=8,
     scale=4,
+    device,
 ):
-    tile = tile or opts.SWIN_tile
-    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
     img = np.array(img)
     img = img[:, :, ::-1]
     img = np.moveaxis(img, 2, 0) / 255
     img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+    img = img.unsqueeze(0).to(device, dtype=devices.dtype)
     with torch.no_grad(), devices.autocast():
         _, _, h_old, w_old = img.size()
         h_pad = (h_old // window_size + 1) * window_size - h_old
         w_pad = (w_old // window_size + 1) * window_size - w_old
         img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
         img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(img, model, tile, tile_overlap, window_size, scale)
+        output = inference(
+            img,
+            model,
+            tile=tile,
+            tile_overlap=tile_overlap,
+            window_size=window_size,
+            scale=scale,
+            device=device,
+        )
         output = output[..., : h_old * scale, : w_old * scale]
         output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
         if output.ndim == 3:
@@ -142,7 +128,16 @@ def upscale(
     return Image.fromarray(output, "RGB")

-def inference(img, model, tile, tile_overlap, window_size, scale):
+def inference(
+    img,
+    model,
+    *,
+    tile: int,
+    tile_overlap: int,
+    window_size: int,
+    scale: int,
+    device,
+):
     # test the image tile by tile
     b, c, h, w = img.size()
     tile = min(tile, h, w)
@@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     stride = tile - tile_overlap
     h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
     w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
+    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
+    W = torch.zeros_like(E, dtype=devices.dtype, device=device)

     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
@@ -185,7 +180,6 @@ def on_ui_settings():
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
-    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":    # torch.compile() require pytorch 2.0 or above, and not on Windows
-        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+    shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
...
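For reference, the tiling scheme that `upscale()`/`inference()` keep (now parameterized by `device` instead of the old module-level `device_swinir`): tiles are processed with overlap, each tile's output is accumulated into a sum buffer `E` while a weight map `W` counts coverage, and dividing `E` by `W` averages the overlapping regions so no seams appear. A self-contained sketch of that overlap-averaging idea, with toy sizes and an identity model standing in for SwinIR:

```python
import torch

def tile_positions(size: int, tile: int, stride: int) -> list[int]:
    # same positions as inference(): a strided sweep plus one final tile
    # flush against the edge
    return list(range(0, size - tile, stride)) + [size - tile]

b, c, h, w, tile, overlap, sf = 1, 3, 96, 96, 64, 8, 1
img = torch.rand(b, c, h, w)
model = torch.nn.Identity()  # stand-in for the real SwinIR model
stride = tile - overlap
E = torch.zeros(b, c, h * sf, w * sf)  # sum of tile outputs
W = torch.zeros_like(E)                # how many tiles covered each pixel
for hi in tile_positions(h, tile, stride):
    for wi in tile_positions(w, tile, stride):
        patch = img[..., hi:hi + tile, wi:wi + tile]
        out = model(patch)
        E[..., hi * sf:(hi + tile) * sf, wi * sf:(wi + tile) * sf] += out
        W[..., hi * sf:(hi + tile) * sf, wi * sf:(wi + tile) * sf] += 1
result = E / W  # overlapping areas are averaged, hiding tile seams
```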
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
modules/codeformer_model.py

@@ -8,9 +8,6 @@ import modules.shared
 from modules import shared, devices, modelloader, errors
 from modules.paths import models_path

-# codeformer people made a choice to include modified basicsr library to their project which makes
-# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
 model_dir = "Codeformer"
 model_path = os.path.join(models_path, model_dir)
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
@@ -18,23 +15,7 @@ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codef
 codeformer = None

-def setup_model(dirname):
-    os.makedirs(model_path, exist_ok=True)
-
-    path = modules.paths.paths.get("CodeFormer", None)
-    if path is None:
-        return
-
-    try:
-        from torchvision.transforms.functional import normalize
-        from modules.codeformer.codeformer_arch import CodeFormer
-        from basicsr.utils import img2tensor, tensor2img
-        from facelib.utils.face_restoration_helper import FaceRestoreHelper
-        from facelib.detection.retinaface import retinaface
-
-        net_class = CodeFormer
-
-        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
     def name(self):
         return "CodeFormer"
@@ -44,36 +25,51 @@ def setup_model(dirname):
         self.cmd_dir = dirname

     def create_models(self):
+        from facexlib.detection import retinaface
+        from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
         if self.net is not None and self.face_helper is not None:
             self.net.to(devices.device_codeformer)
             return self.net, self.face_helper

-        model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
+        model_paths = modelloader.load_models(
+            model_path,
+            model_url,
+            self.cmd_dir,
+            download_name='codeformer-v0.1.0.pth',
+            ext_filter=['.pth'],
+        )
         if len(model_paths) != 0:
             ckpt_path = model_paths[0]
         else:
             print("Unable to load codeformer model.")
             return None, None
-        net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
-        checkpoint = torch.load(ckpt_path)['params_ema']
-        net.load_state_dict(checkpoint)
-        net.eval()
+        net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)

         if hasattr(retinaface, 'device'):
             retinaface.device = devices.device_codeformer
-        face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
+
+        face_helper = FaceRestoreHelper(
+            upscale_factor=1,
+            face_size=512,
+            crop_ratio=(1, 1),
+            det_model='retinaface_resnet50',
+            save_ext='png',
+            use_parse=True,
+            device=devices.device_codeformer,
+        )

         self.net = net
         self.face_helper = face_helper
-        return net, face_helper

     def send_model_to(self, device):
         self.net.to(device)
         self.face_helper.face_det.to(device)
         self.face_helper.face_parse.to(device)

     def restore(self, np_image, w=None):
+        from torchvision.transforms.functional import normalize
+        from basicsr.utils import img2tensor, tensor2img
+
         np_image = np_image[:, :, ::-1]

         original_resolution = np_image.shape[0:2]
@@ -96,7 +92,13 @@ def setup_model(dirname):
                 try:
                     with torch.no_grad():
-                        output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
+                        res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
+                        if isinstance(res, tuple):
+                            output = res[0]
+                        else:
+                            output = res
+                        if not isinstance(res, torch.Tensor):
+                            raise TypeError(f"Expected torch.Tensor, got {type(res)}")
                         restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                     del output
                     devices.torch_gc()
@@ -113,7 +115,13 @@ def setup_model(dirname):
         restored_img = restored_img[:, :, ::-1]

         if original_resolution != restored_img.shape[0:2]:
-            restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+            restored_img = cv2.resize(
+                restored_img,
+                (0, 0),
+                fx=original_resolution[1]/restored_img.shape[1],
+                fy=original_resolution[0]/restored_img.shape[0],
+                interpolation=cv2.INTER_LINEAR,
+            )

         self.face_helper.clean_all()
@@ -122,11 +130,12 @@ def setup_model(dirname):
         return restored_img

+def setup_model(dirname):
+    os.makedirs(model_path, exist_ok=True)
+    try:
         global codeformer
         codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)
     except Exception:
         errors.report("Error setting up CodeFormer", exc_info=True)
-
-    # sys.path = stored_sys_path
modules/esrgan_model.py

-import sys
-
-import torch
-
-import modules.esrgan_model_arch as arch
-from modules import modelloader, devices
+from modules import modelloader, devices, errors
 from modules.shared import opts
 from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model

-def mod2normal(state_dict):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    if 'conv_first.weight' in state_dict:
-        crt_net = {}
-        items = list(state_dict)
-
-        crt_net['model.0.weight'] = state_dict['conv_first.weight']
-        crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
-        for k in items.copy():
-            if 'RDB' in k:
-                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
-                if '.weight' in k:
-                    ori_k = ori_k.replace('.weight', '.0.weight')
-                elif '.bias' in k:
-                    ori_k = ori_k.replace('.bias', '.0.bias')
-                crt_net[ori_k] = state_dict[k]
-                items.remove(k)
-
-        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
-        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
-        crt_net['model.3.weight'] = state_dict['upconv1.weight']
-        crt_net['model.3.bias'] = state_dict['upconv1.bias']
-        crt_net['model.6.weight'] = state_dict['upconv2.weight']
-        crt_net['model.6.bias'] = state_dict['upconv2.bias']
-        crt_net['model.8.weight'] = state_dict['HRconv.weight']
-        crt_net['model.8.bias'] = state_dict['HRconv.bias']
-        crt_net['model.10.weight'] = state_dict['conv_last.weight']
-        crt_net['model.10.bias'] = state_dict['conv_last.bias']
-        state_dict = crt_net
-    return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
-        re8x = 0
-        crt_net = {}
-        items = list(state_dict)
-
-        crt_net['model.0.weight'] = state_dict['conv_first.weight']
-        crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
-        for k in items.copy():
-            if "rdb" in k:
-                ori_k = k.replace('body.', 'model.1.sub.')
-                ori_k = ori_k.replace('.rdb', '.RDB')
-                if '.weight' in k:
-                    ori_k = ori_k.replace('.weight', '.0.weight')
-                elif '.bias' in k:
-                    ori_k = ori_k.replace('.bias', '.0.bias')
-                crt_net[ori_k] = state_dict[k]
-                items.remove(k)
-
-        crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
-        crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
-        crt_net['model.3.weight'] = state_dict['conv_up1.weight']
-        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
-        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
-        crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-
-        if 'conv_up3.weight' in state_dict:
-            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
-            re8x = 3
-            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
-            crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-
-        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
-        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
-        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
-        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-        state_dict = crt_net
-    return state_dict
-
-
-def infer_params(state_dict):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    scale2x = 0
-    scalemin = 6
-    n_uplayer = 0
-    plus = False
-
-    for block in list(state_dict):
-        parts = block.split(".")
-        n_parts = len(parts)
-        if n_parts == 5 and parts[2] == "sub":
-            nb = int(parts[3])
-        elif n_parts == 3:
-            part_num = int(parts[1])
-            if (part_num > scalemin
-                and parts[0] == "model"
-                and parts[2] == "weight"):
-                scale2x += 1
-            if part_num > n_uplayer:
-                n_uplayer = part_num
-                out_nc = state_dict[block].shape[0]
-        if not plus and "conv1x1" in block:
-            plus = True
-
-    nf = state_dict["model.0.weight"].shape[0]
-    in_nc = state_dict["model.0.weight"].shape[1]
-    out_nc = out_nc
-    scale = 2 ** scale2x
-    return in_nc, out_nc, nf, nb, plus, scale
-

 class UpscalerESRGAN(Upscaler):
     def __init__(self, dirname):
         self.name = "ESRGAN"
@@ -142,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
     def do_upscale(self, img, selected_model):
         try:
             model = self.load_model(selected_model)
-        except Exception as e:
-            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
+        except Exception:
+            errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
             return img
         model.to(devices.device_esrgan)
-        img = esrgan_upscale(model, img)
-        return img
+        return esrgan_upscale(model, img)

     def load_model(self, path: str):
         if path.startswith("http"):
@@ -160,33 +46,10 @@ class UpscalerESRGAN(Upscaler):
         else:
             filename = path

-        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-
-        if "params_ema" in state_dict:
-            state_dict = state_dict["params_ema"]
-        elif "params" in state_dict:
-            state_dict = state_dict["params"]
-            num_conv = 16 if "realesr-animevideov3" in filename else 32
-            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
-            model.load_state_dict(state_dict)
-            model.eval()
-            return model
-
-        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
-            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
-            state_dict = resrgan2normal(state_dict, nb)
-        elif "conv_first.weight" in state_dict:
-            state_dict = mod2normal(state_dict)
-        elif "model.0.weight" not in state_dict:
-            raise Exception("The file is not a recognized ESRGAN model.")
-
-        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
-        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
-        model.load_state_dict(state_dict)
-        model.eval()
-
-        return model
+        return modelloader.load_spandrel_model(
+            filename,
+            device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+        )

 def esrgan_upscale(model, img):
...
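Why architecture auto-detection can replace all of the deleted code above: ESRGAN-family checkpoints carry distinctive state-dict keys, and `mod2normal`/`resrgan2normal`/`infer_params` keyed off exactly those names. Spandrel applies the same kind of fingerprinting internally. A toy illustration of the idea (not Spandrel's API; the checkpoint path is hypothetical):

```python
import torch

# Load a checkpoint and unwrap the Real-ESRGAN-style "params"/"params_ema"
# wrappers, as the deleted load_model() did.
sd = torch.load("models/ESRGAN/4x_example.pth", map_location="cpu")  # hypothetical path
sd = sd.get("params_ema", sd.get("params", sd))

# The same key checks the deleted code used to pick a constructor:
if "body.0.rdb1.conv1.weight" in sd:
    print("new-style Real-ESRGAN RRDBNet checkpoint")
elif "conv_first.weight" in sd:
    print("modified/old ESRGAN layout")
elif "model.0.weight" in sd:
    print("original ESRGAN layout")
else:
    raise ValueError("not a recognized ESRGAN model")
```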
This diff is collapsed.
modules/gfpgan_model.py

 import os

-import facexlib
-import gfpgan
-
 import modules.face_restoration
 from modules import paths, shared, devices, modelloader, errors
@@ -41,6 +38,8 @@ def gfpgann():
         print("Unable to load gfpgan model!")
         return None

+    import facexlib.detection.retinaface
+
     if hasattr(facexlib.detection.retinaface, 'device'):
         facexlib.detection.retinaface.device = devices.device_gfpgan
     model_file_path = model_file
@@ -81,8 +80,10 @@ gfpgan_constructor = None
 def setup_model(dirname):
     try:
         os.makedirs(model_path, exist_ok=True)
-        from gfpgan import GFPGANer
-        from facexlib import detection, parsing  # noqa: F401
+        import gfpgan
+        import facexlib.detection
+        import facexlib.parsing
+
         global user_path
         global have_gfpgan
         global gfpgan_constructor
@@ -111,7 +112,7 @@ def setup_model(dirname):
         facexlib.parsing.load_file_from_url = facex_load_file_from_url2
         user_path = dirname
         have_gfpgan = True
-        gfpgan_constructor = GFPGANer
+        gfpgan_constructor = gfpgan.GFPGANer

 class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
     def name(self):
...
modules/launch_utils.py

@@ -345,13 +345,11 @@ def prepare_environment():
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
     stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
-    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

     try:
@@ -408,15 +406,10 @@ def prepare_environment():
         git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
         git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
         git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
-        git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
         git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

         startup_timer.record("clone repositores")

-        if not is_installed("lpips"):
-            run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
-            startup_timer.record("install CodeFormer requirements")
-
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)
...
modules/modelloader.py

 from __future__ import annotations

+import logging
 import os
 import shutil
 import importlib
@@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
 from modules.paths import script_path, models_path

+logger = logging.getLogger(__name__)

 def load_file_from_url(
     url: str,
     *,
@@ -177,3 +181,15 @@ def load_upscalers():
         # Special case for UpscalerNone keeps it at the beginning of the list.
         key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
     )

+def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
+    import spandrel
+    model = spandrel.ModelLoader(device=device).load_from_file(path)
+    if half:
+        model = model.model.half()
+    if dtype:
+        model = model.model.to(dtype=dtype)
+    model.eval()
+    logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
+    return model
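A hedged usage sketch for the new helper; the checkpoint path is illustrative, and `half`/`dtype` mirror the call sites in the SwinIR and Real-ESRGAN changes:

```python
# Illustrative use of the new modules.modelloader.load_spandrel_model helper
# defined above; the checkpoint path is hypothetical.
from modules import devices, modelloader

model = modelloader.load_spandrel_model(
    "models/SwinIR/example_x4.pth",           # hypothetical local checkpoint
    device=devices.get_device_for('swinir'),
    dtype=devices.dtype,
)
# The result is in eval mode; callers may still move it afterwards, e.g.
# model.to(devices.device_esrgan), as the ESRGAN upscaler does.
```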
modules/paths.py

@@ -38,7 +38,6 @@ mute_sdxl_imports()
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
     (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
-    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
 ]
...
modules/realesrgan_model.py

 import os

-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
+from modules.upscaler_utils import upscale_with_model
 from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import cmd_opts, opts
 from modules import modelloader, errors
@@ -14,13 +11,9 @@ class UpscalerRealESRGAN(Upscaler):
         self.name = "RealESRGAN"
         self.user_path = path
         super().__init__()
-        try:
-            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
-            from realesrgan import RealESRGANer  # noqa: F401
-            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
-            self.enable = True
-            self.scalers = []
-            scalers = self.load_models(path)
+        self.enable = True
+        self.scalers = []
+        scalers = get_realesrgan_models(self)

         local_model_paths = self.find_models(ext_filter=[".pth"])
         for scaler in scalers:
@@ -33,11 +26,6 @@ class UpscalerRealESRGAN(Upscaler):
             if scaler.name in opts.realesrgan_enabled_models:
                 self.scalers.append(scaler)

-        except Exception:
-            errors.report("Error importing Real-ESRGAN", exc_info=True)
-            self.enable = False
-            self.scalers = []
-
     def do_upscale(self, img, path):
         if not self.enable:
             return img
@@ -48,20 +36,18 @@ class UpscalerRealESRGAN(Upscaler):
             errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
             return img

-        upsampler = RealESRGANer(
-            scale=info.scale,
-            model_path=info.local_data_path,
-            model=info.model(),
-            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
-            tile=opts.ESRGAN_tile,
-            tile_pad=opts.ESRGAN_tile_overlap,
+        mod = modelloader.load_spandrel_model(
+            info.local_data_path,
             device=self.device,
+            half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
         )
-
-        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-        image = Image.fromarray(upsampled)
-        return image
+        return upscale_with_model(
+            mod,
+            img,
+            tile_size=opts.ESRGAN_tile,
+            tile_overlap=opts.ESRGAN_tile_overlap,
+            # TODO: `outscale`?
+        )

     def load_model(self, path):
         for scaler in self.scalers:
@@ -76,58 +62,43 @@ class UpscalerRealESRGAN(Upscaler):
                 return scaler
         raise ValueError(f"Unable to find model info: {path}")

-    def load_models(self, _):
-        return get_realesrgan_models(self)

-def get_realesrgan_models(scaler):
-    try:
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-        models = [
+def get_realesrgan_models(scaler: UpscalerRealESRGAN):
+    return [
         UpscalerData(
             name="R-ESRGAN General 4xV3",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN General WDN 4xV3",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN AnimeVideo",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN 4x+",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
         ),
         UpscalerData(
             name="R-ESRGAN 4x+ Anime6B",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
         ),
         UpscalerData(
             name="R-ESRGAN 2x+",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
             scale=2,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
         ),
     ]
-        return models
-    except Exception:
-        errors.report("Error making Real-ESRGAN models list", exc_info=True)
modules/sysinfo.py

@@ -26,11 +26,9 @@ environment_whitelist = {
     "OPENCLIP_PACKAGE",
     "STABLE_DIFFUSION_REPO",
     "K_DIFFUSION_REPO",
-    "CODEFORMER_REPO",
     "BLIP_REPO",
     "STABLE_DIFFUSION_COMMIT_HASH",
     "K_DIFFUSION_COMMIT_HASH",
-    "CODEFORMER_COMMIT_HASH",
     "BLIP_COMMIT_HASH",
     "COMMANDLINE_ARGS",
     "IGNORE_CMD_ARGS_ERRORS",
...
modules/upscaler.py

@@ -98,6 +98,9 @@ class UpscalerData:
         self.scale = scale
         self.model = model

+    def __repr__(self):
+        return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"

 class UpscalerNone(Upscaler):
     name = "None"
...
requirements_versions.txt

@@ -5,6 +5,7 @@ basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
 einops==0.4.1
+facexlib==0.3.0
 fastapi==0.94.0
 gfpgan==1.3.8
 gradio==3.41.2
@@ -19,11 +20,10 @@ open-clip-torch==2.20.0
 piexif==1.1.3
 psutil==5.9.5
 pytorch_lightning==1.9.4
-realesrgan==0.3.0
 resize-right==0.0.2
 safetensors==0.3.1
 scikit-image==0.21.0
-timm==0.9.2
+spandrel==0.1.6
 tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
...