Commit b0f59342 authored by Aarni Koskela's avatar Aarni Koskela

Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR)

parent e472383a
...@@ -7,9 +7,7 @@ from tqdm import tqdm ...@@ -7,9 +7,7 @@ from tqdm import tqdm
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url
from modules.shared import opts from modules.shared import opts
...@@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
if path.startswith("http"): if path.startswith("http"):
# TODO: this doesn't use `path` at all? # TODO: this doesn't use `path` at all?
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) return modelloader.load_spandrel_model(filename, device=device)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
def on_ui_settings(): def on_ui_settings():
......
This diff is collapsed.
import logging
import sys import sys
import platform
import numpy as np import numpy as np
import torch import torch
...@@ -8,13 +8,11 @@ from tqdm import tqdm ...@@ -8,13 +8,11 @@ from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir') logger = logging.getLogger(__name__)
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
...@@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler): ...@@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler):
scalers.append(model_data) scalers.append(model_data)
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img, model_file): def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
current_config = (model_file, opts.SWIN_tile) current_config = (model_file, opts.SWIN_tile)
if use_compile and self._cached_model_config == current_config: device = self._get_device()
if self._cached_model_config == current_config:
model = self._cached_model model = self._cached_model
else: else:
self._cached_model = None
try: try:
model = self.load_model(model_file) model = self.load_model(model_file)
except Exception as e: except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img return img
model = model.to(device_swinir, dtype=devices.dtype) self._cached_model = model
if use_compile: self._cached_model_config = current_config
model = torch.compile(model)
self._cached_model = model img = upscale(
self._cached_model_config = current_config img,
img = upscale(img, model) model,
tile=opts.SWIN_tile,
tile_overlap=opts.SWIN_tile_overlap,
device=device,
)
devices.torch_gc() devices.torch_gc()
return img return img
...@@ -69,69 +70,54 @@ class UpscalerSwinIR(Upscaler): ...@@ -69,69 +70,54 @@ class UpscalerSwinIR(Upscaler):
) )
else: else:
filename = path filename = path
if filename.endswith(".v2.pth"):
model = Swin2SR(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
)
params = None
else:
model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
params = "params_ema"
pretrained_model = torch.load(filename) model = modelloader.load_spandrel_model(
if params is not None: filename,
model.load_state_dict(pretrained_model[params], strict=True) device=self._get_device(),
else: dtype=devices.dtype,
model.load_state_dict(pretrained_model, strict=True) )
if getattr(opts, 'SWIN_torch_compile', False):
try:
model = torch.compile(model)
except Exception:
logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
return model return model
def _get_device(self):
return devices.get_device_for('swinir')
def upscale( def upscale(
img, img,
model, model,
tile=None, *,
tile_overlap=None, tile: int,
window_size=8, tile_overlap: int,
scale=4, window_size=8,
scale=4,
device,
): ):
tile = tile or opts.SWIN_tile
tile_overlap = tile_overlap or opts.SWIN_tile_overlap
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) img = img.unsqueeze(0).to(device, dtype=devices.dtype)
with torch.no_grad(), devices.autocast(): with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference(img, model, tile, tile_overlap, window_size, scale) output = inference(
img,
model,
tile=tile,
tile_overlap=tile_overlap,
window_size=window_size,
scale=scale,
device=device,
)
output = output[..., : h_old * scale, : w_old * scale] output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3: if output.ndim == 3:
...@@ -142,7 +128,16 @@ def upscale( ...@@ -142,7 +128,16 @@ def upscale(
return Image.fromarray(output, "RGB") return Image.fromarray(output, "RGB")
def inference(img, model, tile, tile_overlap, window_size, scale): def inference(
img,
model,
*,
tile: int,
tile_overlap: int,
window_size: int,
scale: int,
device,
):
# test the image tile by tile # test the image tile by tile
b, c, h, w = img.size() b, c, h, w = img.size()
tile = min(tile, h, w) tile = min(tile, h, w)
...@@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile] h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list: for h_idx in h_idx_list:
...@@ -185,8 +180,7 @@ def on_ui_settings(): ...@@ -185,8 +180,7 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_settings(on_ui_settings)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import sys from modules import modelloader, devices, errors
import torch
import modules.esrgan_model_arch as arch
from modules import modelloader, devices
from modules.shared import opts from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model from modules.upscaler_utils import upscale_with_model
def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
for k in items.copy():
if 'RDB' in k:
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[ori_k] = state_dict[k]
items.remove(k)
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
crt_net['model.3.weight'] = state_dict['upconv1.weight']
crt_net['model.3.bias'] = state_dict['upconv1.bias']
crt_net['model.6.weight'] = state_dict['upconv2.weight']
crt_net['model.6.bias'] = state_dict['upconv2.bias']
crt_net['model.8.weight'] = state_dict['HRconv.weight']
crt_net['model.8.bias'] = state_dict['HRconv.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
for k in items.copy():
if "rdb" in k:
ori_k = k.replace('body.', 'model.1.sub.')
ori_k = ori_k.replace('.rdb', '.RDB')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[ori_k] = state_dict[k]
items.remove(k)
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
if 'conv_up3.weight' in state_dict:
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
re8x = 3
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
def infer_params(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
scale2x = 0
scalemin = 6
n_uplayer = 0
plus = False
for block in list(state_dict):
parts = block.split(".")
n_parts = len(parts)
if n_parts == 5 and parts[2] == "sub":
nb = int(parts[3])
elif n_parts == 3:
part_num = int(parts[1])
if (part_num > scalemin
and parts[0] == "model"
and parts[2] == "weight"):
scale2x += 1
if part_num > n_uplayer:
n_uplayer = part_num
out_nc = state_dict[block].shape[0]
if not plus and "conv1x1" in block:
plus = True
nf = state_dict["model.0.weight"].shape[0]
in_nc = state_dict["model.0.weight"].shape[1]
out_nc = out_nc
scale = 2 ** scale2x
return in_nc, out_nc, nf, nb, plus, scale
class UpscalerESRGAN(Upscaler): class UpscalerESRGAN(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "ESRGAN" self.name = "ESRGAN"
...@@ -142,12 +29,11 @@ class UpscalerESRGAN(Upscaler): ...@@ -142,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
def do_upscale(self, img, selected_model): def do_upscale(self, img, selected_model):
try: try:
model = self.load_model(selected_model) model = self.load_model(selected_model)
except Exception as e: except Exception:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
return img return img
model.to(devices.device_esrgan) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) return esrgan_upscale(model, img)
return img
def load_model(self, path: str): def load_model(self, path: str):
if path.startswith("http"): if path.startswith("http"):
...@@ -160,33 +46,10 @@ class UpscalerESRGAN(Upscaler): ...@@ -160,33 +46,10 @@ class UpscalerESRGAN(Upscaler):
else: else:
filename = path filename = path
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) return modelloader.load_spandrel_model(
filename,
if "params_ema" in state_dict: device=('cpu' if devices.device_esrgan.type == 'mps' else None),
state_dict = state_dict["params_ema"] )
elif "params" in state_dict:
state_dict = state_dict["params"]
num_conv = 16 if "realesr-animevideov3" in filename else 32
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
model.load_state_dict(state_dict)
model.eval()
return model
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
state_dict = resrgan2normal(state_dict, nb)
elif "conv_first.weight" in state_dict:
state_dict = mod2normal(state_dict)
elif "model.0.weight" not in state_dict:
raise Exception("The file is not a recognized ESRGAN model.")
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
model.load_state_dict(state_dict)
model.eval()
return model
def esrgan_upscale(model, img): def esrgan_upscale(model, img):
......
This diff is collapsed.
import os import os
import facexlib
import gfpgan
import modules.face_restoration import modules.face_restoration
from modules import paths, shared, devices, modelloader, errors from modules import paths, shared, devices, modelloader, errors
...@@ -41,6 +38,8 @@ def gfpgann(): ...@@ -41,6 +38,8 @@ def gfpgann():
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
import facexlib.detection.retinaface
if hasattr(facexlib.detection.retinaface, 'device'): if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan facexlib.detection.retinaface.device = devices.device_gfpgan
model_file_path = model_file model_file_path = model_file
...@@ -81,8 +80,10 @@ gfpgan_constructor = None ...@@ -81,8 +80,10 @@ gfpgan_constructor = None
def setup_model(dirname): def setup_model(dirname):
try: try:
os.makedirs(model_path, exist_ok=True) os.makedirs(model_path, exist_ok=True)
from gfpgan import GFPGANer import gfpgan
from facexlib import detection, parsing # noqa: F401 import facexlib.detection
import facexlib.parsing
global user_path global user_path
global have_gfpgan global have_gfpgan
global gfpgan_constructor global gfpgan_constructor
...@@ -111,7 +112,7 @@ def setup_model(dirname): ...@@ -111,7 +112,7 @@ def setup_model(dirname):
facexlib.parsing.load_file_from_url = facex_load_file_from_url2 facexlib.parsing.load_file_from_url = facex_load_file_from_url2
user_path = dirname user_path = dirname
have_gfpgan = True have_gfpgan = True
gfpgan_constructor = GFPGANer gfpgan_constructor = gfpgan.GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self): def name(self):
......
...@@ -345,13 +345,11 @@ def prepare_environment(): ...@@ -345,13 +345,11 @@ def prepare_environment():
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try: try:
...@@ -408,15 +406,10 @@ def prepare_environment(): ...@@ -408,15 +406,10 @@ def prepare_environment():
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
startup_timer.record("clone repositores") startup_timer.record("clone repositores")
if not is_installed("lpips"):
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
startup_timer.record("install CodeFormer requirements")
if not os.path.isfile(requirements_file): if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file) requirements_file = os.path.join(script_path, requirements_file)
......
from __future__ import annotations from __future__ import annotations
import logging
import os import os
import shutil import shutil
import importlib import importlib
...@@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale ...@@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
logger = logging.getLogger(__name__)
def load_file_from_url( def load_file_from_url(
url: str, url: str,
*, *,
...@@ -177,3 +181,15 @@ def load_upscalers(): ...@@ -177,3 +181,15 @@ def load_upscalers():
# Special case for UpscalerNone keeps it at the beginning of the list. # Special case for UpscalerNone keeps it at the beginning of the list.
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
) )
def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
if half:
model = model.model.half()
if dtype:
model = model.model.to(dtype=dtype)
model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
return model
...@@ -38,7 +38,6 @@ mute_sdxl_imports() ...@@ -38,7 +38,6 @@ mute_sdxl_imports()
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []), (sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
] ]
......
import os import os
import numpy as np from modules.upscaler_utils import upscale_with_model
from PIL import Image
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
from modules import modelloader, errors from modules import modelloader, errors
...@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler):
self.name = "RealESRGAN" self.name = "RealESRGAN"
self.user_path = path self.user_path = path
super().__init__() super().__init__()
try: self.enable = True
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401 self.scalers = []
from realesrgan import RealESRGANer # noqa: F401 scalers = get_realesrgan_models(self)
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True
self.scalers = []
scalers = self.load_models(path)
local_model_paths = self.find_models(ext_filter=[".pth"]) local_model_paths = self.find_models(ext_filter=[".pth"])
for scaler in scalers: for scaler in scalers:
if scaler.local_data_path.startswith("http"): if scaler.local_data_path.startswith("http"):
filename = modelloader.friendly_name(scaler.local_data_path) filename = modelloader.friendly_name(scaler.local_data_path)
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
if local_model_candidates: if local_model_candidates:
scaler.local_data_path = local_model_candidates[0] scaler.local_data_path = local_model_candidates[0]
if scaler.name in opts.realesrgan_enabled_models: if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler) self.scalers.append(scaler)
except Exception:
errors.report("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
def do_upscale(self, img, path): def do_upscale(self, img, path):
if not self.enable: if not self.enable:
...@@ -48,20 +36,18 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -48,20 +36,18 @@ class UpscalerRealESRGAN(Upscaler):
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img return img
upsampler = RealESRGANer( mod = modelloader.load_spandrel_model(
scale=info.scale, info.local_data_path,
model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
device=self.device, device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
)
return upscale_with_model(
mod,
img,
tile_size=opts.ESRGAN_tile,
tile_overlap=opts.ESRGAN_tile_overlap,
# TODO: `outscale`?
) )
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
image = Image.fromarray(upsampled)
return image
def load_model(self, path): def load_model(self, path):
for scaler in self.scalers: for scaler in self.scalers:
...@@ -76,58 +62,43 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -76,58 +62,43 @@ class UpscalerRealESRGAN(Upscaler):
return scaler return scaler
raise ValueError(f"Unable to find model info: {path}") raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
def get_realesrgan_models(scaler): def get_realesrgan_models(scaler: UpscalerRealESRGAN):
try: return [
from basicsr.archs.rrdbnet_arch import RRDBNet UpscalerData(
from realesrgan.archs.srvgg_arch import SRVGGNetCompact name="R-ESRGAN General 4xV3",
models = [ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
UpscalerData( scale=4,
name="R-ESRGAN General 4xV3", upscaler=scaler,
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", ),
scale=4, UpscalerData(
upscaler=scaler, name="R-ESRGAN General WDN 4xV3",
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
), scale=4,
UpscalerData( upscaler=scaler,
name="R-ESRGAN General WDN 4xV3", ),
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", UpscalerData(
scale=4, name="R-ESRGAN AnimeVideo",
upscaler=scaler, path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') scale=4,
), upscaler=scaler,
UpscalerData( ),
name="R-ESRGAN AnimeVideo", UpscalerData(
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", name="R-ESRGAN 4x+",
scale=4, path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
upscaler=scaler, scale=4,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') upscaler=scaler,
), ),
UpscalerData( UpscalerData(
name="R-ESRGAN 4x+", name="R-ESRGAN 4x+ Anime6B",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
scale=4, scale=4,
upscaler=scaler, upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) ),
), UpscalerData(
UpscalerData( name="R-ESRGAN 2x+",
name="R-ESRGAN 4x+ Anime6B", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", scale=2,
scale=4, upscaler=scaler,
upscaler=scaler, ),
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) ]
),
UpscalerData(
name="R-ESRGAN 2x+",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
scale=2,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
]
return models
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)
...@@ -26,11 +26,9 @@ environment_whitelist = { ...@@ -26,11 +26,9 @@ environment_whitelist = {
"OPENCLIP_PACKAGE", "OPENCLIP_PACKAGE",
"STABLE_DIFFUSION_REPO", "STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO", "K_DIFFUSION_REPO",
"CODEFORMER_REPO",
"BLIP_REPO", "BLIP_REPO",
"STABLE_DIFFUSION_COMMIT_HASH", "STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH",
"CODEFORMER_COMMIT_HASH",
"BLIP_COMMIT_HASH", "BLIP_COMMIT_HASH",
"COMMANDLINE_ARGS", "COMMANDLINE_ARGS",
"IGNORE_CMD_ARGS_ERRORS", "IGNORE_CMD_ARGS_ERRORS",
......
...@@ -98,6 +98,9 @@ class UpscalerData: ...@@ -98,6 +98,9 @@ class UpscalerData:
self.scale = scale self.scale = scale
self.model = model self.model = model
def __repr__(self):
return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
class UpscalerNone(Upscaler): class UpscalerNone(Upscaler):
name = "None" name = "None"
......
...@@ -5,6 +5,7 @@ basicsr==1.4.2 ...@@ -5,6 +5,7 @@ basicsr==1.4.2
blendmodes==2022 blendmodes==2022
clean-fid==0.1.35 clean-fid==0.1.35
einops==0.4.1 einops==0.4.1
facexlib==0.3.0
fastapi==0.94.0 fastapi==0.94.0
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.41.2 gradio==3.41.2
...@@ -19,11 +20,10 @@ open-clip-torch==2.20.0 ...@@ -19,11 +20,10 @@ open-clip-torch==2.20.0
piexif==1.1.3 piexif==1.1.3
psutil==5.9.5 psutil==5.9.5
pytorch_lightning==1.9.4 pytorch_lightning==1.9.4
realesrgan==0.3.0
resize-right==0.0.2 resize-right==0.0.2
safetensors==0.3.1 safetensors==0.3.1
scikit-image==0.21.0 scikit-image==0.21.0
timm==0.9.2 spandrel==0.1.6
tomesd==0.1.3 tomesd==0.1.3
torch torch
torchdiffeq==0.2.3 torchdiffeq==0.2.3
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment