Commit cd12c0e1 authored by AUTOMATIC1111, committed by GitHub

Merge pull request #14425 from akx/spandrel

Use Spandrel for upscaling and face restoration architectures
parents 05230c02 4ad0c0c0
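This merge swaps the webui's hand-maintained ESRGAN/SwinIR/ScuNET/HAT/CodeFormer/GFPGAN network definitions for the spandrel model-loading library. A minimal sketch of the spandrel 0.1.x API the change builds on (the file path is hypothetical):

    import spandrel

    # load_from_file() inspects the state dict and picks the matching architecture
    descriptor = spandrel.ModelLoader(device="cpu").load_from_file("models/ESRGAN/4x_example.pth")
    print(descriptor.architecture)  # e.g. "ESRGAN"
    net = descriptor.model          # the underlying torch.nn.Module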
.github/workflows/run_tests.yaml
@@ -20,6 +20,12 @@ jobs:
           cache-dependency-path: |
             **/requirements*txt
             launch.py
+      - name: Cache models
+        id: cache-models
+        uses: actions/cache@v3
+        with:
+          path: models
+          key: "2023-12-30"
       - name: Install test dependencies
         run: pip install wait-for-it -r requirements-test.txt
         env:
@@ -33,6 +39,8 @@ jobs:
           TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
           WEBUI_LAUNCH_LIVE_OUTPUT: "1"
           PYTHONUNBUFFERED: "1"
+      - name: Print installed packages
+        run: pip freeze
       - name: Start test server
         run: >
           python -m coverage run
.gitignore
@@ -37,3 +37,4 @@ notification.mp3
 /node_modules
 /package-lock.json
 /.coverage*
+/test/test_outputs
extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -7,9 +7,7 @@ from tqdm import tqdm
 import modules.upscaler
 from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-from modules.modelloader import load_file_from_url
 from modules.shared import opts
@@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         device = devices.get_device_for('scunet')
         if path.startswith("http"):
             # TODO: this doesn't use `path` at all?
-            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for _, v in model.named_parameters():
-            v.requires_grad = False
-        model = model.to(device)
-        return model
+        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')


 def on_ui_settings():
extensions-builtin/SwinIR/scripts/swinir_model.py
+import logging
 import sys
-import platform

 import numpy as np
 import torch
@@ -8,13 +8,11 @@ from tqdm import tqdm
 from modules import modelloader, devices, script_callbacks, shared
 from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
 from modules.upscaler import Upscaler, UpscalerData

 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)


 class UpscalerSwinIR(Upscaler):
@@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler):
             scalers.append(model_data)
         self.scalers = scalers

-    def do_upscale(self, img, model_file):
-        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
-            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
         current_config = (model_file, opts.SWIN_tile)
+        device = self._get_device()

-        if use_compile and self._cached_model_config == current_config:
+        if self._cached_model_config == current_config:
             model = self._cached_model
         else:
+            self._cached_model = None
             try:
                 model = self.load_model(model_file)
             except Exception as e:
                 print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                 return img
-            model = model.to(device_swinir, dtype=devices.dtype)
-            if use_compile:
-                model = torch.compile(model)
             self._cached_model = model
             self._cached_model_config = current_config

-        img = upscale(img, model)
+        img = upscale(
+            img,
+            model,
+            tile=opts.SWIN_tile,
+            tile_overlap=opts.SWIN_tile_overlap,
+            device=device,
+        )
         devices.torch_gc()
         return img
@@ -69,69 +70,55 @@ class UpscalerSwinIR(Upscaler):
             )
         else:
             filename = path
-        if filename.endswith(".v2.pth"):
-            model = Swin2SR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6],
-                embed_dim=180,
-                num_heads=[6, 6, 6, 6, 6, 6],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="1conv",
-            )
-            params = None
-        else:
-            model = SwinIR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
-                embed_dim=240,
-                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="3conv",
-            )
-            params = "params_ema"
-
-        pretrained_model = torch.load(filename)
-        if params is not None:
-            model.load_state_dict(pretrained_model[params], strict=True)
-        else:
-            model.load_state_dict(pretrained_model, strict=True)
+        model = modelloader.load_spandrel_model(
+            filename,
+            device=self._get_device(),
+            dtype=devices.dtype,
+            expected_architecture="SwinIR",
+        )
+        if getattr(opts, 'SWIN_torch_compile', False):
+            try:
+                model = torch.compile(model)
+            except Exception:
+                logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
         return model

+    def _get_device(self):
+        return devices.get_device_for('swinir')
+

 def upscale(
     img,
     model,
-    tile=None,
-    tile_overlap=None,
+    *,
+    tile: int,
+    tile_overlap: int,
     window_size=8,
     scale=4,
+    device,
 ):
-    tile = tile or opts.SWIN_tile
-    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
     img = np.array(img)
     img = img[:, :, ::-1]
     img = np.moveaxis(img, 2, 0) / 255
     img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+    img = img.unsqueeze(0).to(device, dtype=devices.dtype)
     with torch.no_grad(), devices.autocast():
         _, _, h_old, w_old = img.size()
         h_pad = (h_old // window_size + 1) * window_size - h_old
         w_pad = (w_old // window_size + 1) * window_size - w_old
         img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
         img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(img, model, tile, tile_overlap, window_size, scale)
+        output = inference(
+            img,
+            model,
+            tile=tile,
+            tile_overlap=tile_overlap,
+            window_size=window_size,
+            scale=scale,
+            device=device,
+        )
         output = output[..., : h_old * scale, : w_old * scale]
         output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
         if output.ndim == 3:
@@ -142,7 +129,16 @@ def upscale(
     return Image.fromarray(output, "RGB")

-def inference(img, model, tile, tile_overlap, window_size, scale):
+def inference(
+    img,
+    model,
+    *,
+    tile: int,
+    tile_overlap: int,
+    window_size: int,
+    scale: int,
+    device,
+):
     # test the image tile by tile
     b, c, h, w = img.size()
     tile = min(tile, h, w)
@@ -152,8 +148,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     stride = tile - tile_overlap
     h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
     w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
+    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
+    W = torch.zeros_like(E, dtype=devices.dtype, device=device)
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
@@ -185,7 +181,6 @@ def on_ui_settings():
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
-    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":    # torch.compile() require pytorch 2.0 or above, and not on Windows
-        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+    shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
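Note that the SWIN_torch_compile option is now registered unconditionally, and the old PyTorch-version/platform gate is replaced by a try/except around torch.compile itself. A standalone sketch of that guarded-compile pattern:

    import torch

    def maybe_compile(model: torch.nn.Module) -> torch.nn.Module:
        # try to compile; fall back to the eager model on any failure
        # (e.g. unsupported platform or torch < 2.0, where torch.compile is absent)
        try:
            return torch.compile(model)
        except Exception:
            return model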
modules/codeformer_model.py
-import os
-
-import cv2
-import torch
-
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
-
-# codeformer people made a choice to include modified basicsr library to their project which makes
-# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
-
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
-model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-
-codeformer = None
+from __future__ import annotations
+
+import logging
+
+import torch
+
+from modules import (
+    devices,
+    errors,
+    face_restoration,
+    face_restoration_utils,
+    modelloader,
+    shared,
+)
+
+logger = logging.getLogger(__name__)
+
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
+
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None


-def setup_model(dirname):
-    os.makedirs(model_path, exist_ok=True)
-
-    path = modules.paths.paths.get("CodeFormer", None)
-    if path is None:
-        return
-
-    try:
-        from torchvision.transforms.functional import normalize
-        from modules.codeformer.codeformer_arch import CodeFormer
-        from basicsr.utils import img2tensor, tensor2img
-        from facelib.utils.face_restoration_helper import FaceRestoreHelper
-        from facelib.detection.retinaface import retinaface
-
-        net_class = CodeFormer
-
-        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
-            def name(self):
-                return "CodeFormer"
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
+    def name(self):
+        return "CodeFormer"

-            def __init__(self, dirname):
-                self.net = None
-                self.face_helper = None
-                self.cmd_dir = dirname
-
-            def create_models(self):
-                if self.net is not None and self.face_helper is not None:
-                    self.net.to(devices.device_codeformer)
-                    return self.net, self.face_helper
-
-                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
-
-                if len(model_paths) != 0:
-                    ckpt_path = model_paths[0]
-                else:
-                    print("Unable to load codeformer model.")
-                    return None, None
-
-                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
-                checkpoint = torch.load(ckpt_path)['params_ema']
-                net.load_state_dict(checkpoint)
-                net.eval()
-
-                if hasattr(retinaface, 'device'):
-                    retinaface.device = devices.device_codeformer
-                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
-
-                self.net = net
-                self.face_helper = face_helper
-
-                return net, face_helper
-
-            def send_model_to(self, device):
-                self.net.to(device)
-                self.face_helper.face_det.to(device)
-                self.face_helper.face_parse.to(device)
+    def load_net(self) -> torch.nn.Module:
+        for model_path in modelloader.load_models(
+            model_path=self.model_path,
+            model_url=model_url,
+            command_path=self.model_path,
+            download_name=model_download_name,
+            ext_filter=['.pth'],
+        ):
+            return modelloader.load_spandrel_model(
+                model_path,
+                device=devices.device_codeformer,
+                expected_architecture='CodeFormer',
+            ).model
+        raise ValueError("No codeformer model found")

-            def restore(self, np_image, w=None):
-                np_image = np_image[:, :, ::-1]
-                original_resolution = np_image.shape[0:2]
-
-                self.create_models()
-                if self.net is None or self.face_helper is None:
-                    return np_image
-
-                self.send_model_to(devices.device_codeformer)
-
-                self.face_helper.clean_all()
-                self.face_helper.read_image(np_image)
-                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
-                self.face_helper.align_warp_face()
-
-                for cropped_face in self.face_helper.cropped_faces:
-                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
-                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
-                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
-
-                    try:
-                        with torch.no_grad():
-                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
-                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
-                        del output
-                        devices.torch_gc()
-                    except Exception:
-                        errors.report('Failed inference for CodeFormer', exc_info=True)
-                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
-
-                    restored_face = restored_face.astype('uint8')
-                    self.face_helper.add_restored_face(restored_face)
-
-                self.face_helper.get_inverse_affine(None)
-
-                restored_img = self.face_helper.paste_faces_to_input_image()
-                restored_img = restored_img[:, :, ::-1]
-
-                if original_resolution != restored_img.shape[0:2]:
-                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
-
-                self.face_helper.clean_all()
-
-                if shared.opts.face_restoration_unload:
-                    self.send_model_to(devices.cpu)
-
-                return restored_img
+    def get_device(self):
+        return devices.device_codeformer
+
+    def restore(self, np_image, w: float | None = None):
+        if w is None:
+            w = getattr(shared.opts, "code_former_weight", 0.5)
+
+        def restore_face(cropped_face_t):
+            assert self.net is not None
+            return self.net(cropped_face_t, w=w, adain=True)[0]
+
+        return self.restore_with_helper(np_image, restore_face)

+
+def setup_model(dirname: str) -> None:
     global codeformer
+    try:
         codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)
     except Exception:
         errors.report("Error setting up CodeFormer", exc_info=True)
-
-    # sys.path = stored_sys_path
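With the boilerplate moved into face_restoration_utils.CommonFaceRestoration, a restorer now only supplies name, load_net, get_device and restore. A hedged usage sketch (assumes the CodeFormer weights can be found or downloaded):

    import numpy as np
    from modules import codeformer_model, shared

    np_image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real photo
    codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
    restored = codeformer_model.codeformer.restore(np_image, w=0.7)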
modules/esrgan_model.py
-import sys
-
-import numpy as np
-import torch
-from PIL import Image
-
-import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices, errors
 from modules.shared import opts
 from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model


-def mod2normal(state_dict):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    if 'conv_first.weight' in state_dict:
-        crt_net = {}
-        items = list(state_dict)
-        crt_net['model.0.weight'] = state_dict['conv_first.weight']
-        crt_net['model.0.bias'] = state_dict['conv_first.bias']
-        for k in items.copy():
-            if 'RDB' in k:
-                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
-                if '.weight' in k:
-                    ori_k = ori_k.replace('.weight', '.0.weight')
-                elif '.bias' in k:
-                    ori_k = ori_k.replace('.bias', '.0.bias')
-                crt_net[ori_k] = state_dict[k]
-                items.remove(k)
-        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
-        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
-        crt_net['model.3.weight'] = state_dict['upconv1.weight']
-        crt_net['model.3.bias'] = state_dict['upconv1.bias']
-        crt_net['model.6.weight'] = state_dict['upconv2.weight']
-        crt_net['model.6.bias'] = state_dict['upconv2.bias']
-        crt_net['model.8.weight'] = state_dict['HRconv.weight']
-        crt_net['model.8.bias'] = state_dict['HRconv.bias']
-        crt_net['model.10.weight'] = state_dict['conv_last.weight']
-        crt_net['model.10.bias'] = state_dict['conv_last.bias']
-        state_dict = crt_net
-    return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
-        re8x = 0
-        crt_net = {}
-        items = list(state_dict)
-        crt_net['model.0.weight'] = state_dict['conv_first.weight']
-        crt_net['model.0.bias'] = state_dict['conv_first.bias']
-        for k in items.copy():
-            if "rdb" in k:
-                ori_k = k.replace('body.', 'model.1.sub.')
-                ori_k = ori_k.replace('.rdb', '.RDB')
-                if '.weight' in k:
-                    ori_k = ori_k.replace('.weight', '.0.weight')
-                elif '.bias' in k:
-                    ori_k = ori_k.replace('.bias', '.0.bias')
-                crt_net[ori_k] = state_dict[k]
-                items.remove(k)
-        crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
-        crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
-        crt_net['model.3.weight'] = state_dict['conv_up1.weight']
-        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
-        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
-        crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-        if 'conv_up3.weight' in state_dict:
-            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
-            re8x = 3
-            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
-            crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
-        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
-        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
-        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-        state_dict = crt_net
-    return state_dict
-
-
-def infer_params(state_dict):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    scale2x = 0
-    scalemin = 6
-    n_uplayer = 0
-    plus = False
-
-    for block in list(state_dict):
-        parts = block.split(".")
-        n_parts = len(parts)
-        if n_parts == 5 and parts[2] == "sub":
-            nb = int(parts[3])
-        elif n_parts == 3:
-            part_num = int(parts[1])
-            if (part_num > scalemin
-                and parts[0] == "model"
-                and parts[2] == "weight"):
-                scale2x += 1
-            if part_num > n_uplayer:
-                n_uplayer = part_num
-                out_nc = state_dict[block].shape[0]
-        if not plus and "conv1x1" in block:
-            plus = True
-
-    nf = state_dict["model.0.weight"].shape[0]
-    in_nc = state_dict["model.0.weight"].shape[1]
-    out_nc = out_nc
-    scale = 2 ** scale2x
-
-    return in_nc, out_nc, nf, nb, plus, scale
-
-
 class UpscalerESRGAN(Upscaler):
@@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
     def do_upscale(self, img, selected_model):
         try:
             model = self.load_model(selected_model)
-        except Exception as e:
-            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
+        except Exception:
+            errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
             return img
         model.to(devices.device_esrgan)
-        img = esrgan_upscale(model, img)
-        return img
+        return esrgan_upscale(model, img)

     def load_model(self, path: str):
         if path.startswith("http"):
@@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
         else:
             filename = path

-        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-
-        if "params_ema" in state_dict:
-            state_dict = state_dict["params_ema"]
-        elif "params" in state_dict:
-            state_dict = state_dict["params"]
-            num_conv = 16 if "realesr-animevideov3" in filename else 32
-            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
-            model.load_state_dict(state_dict)
-            model.eval()
-            return model
-
-        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
-            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
-            state_dict = resrgan2normal(state_dict, nb)
-        elif "conv_first.weight" in state_dict:
-            state_dict = mod2normal(state_dict)
-        elif "model.0.weight" not in state_dict:
-            raise Exception("The file is not a recognized ESRGAN model.")
-
-        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
-        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
-        model.load_state_dict(state_dict)
-        model.eval()
-
-        return model
-
-
-def upscale_without_tiling(model, img):
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(devices.device_esrgan)
-    with torch.no_grad():
-        output = model(img)
-    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-    output = 255. * np.moveaxis(output, 0, 2)
-    output = output.astype(np.uint8)
-    output = output[:, :, ::-1]
-    return Image.fromarray(output, 'RGB')
+        return modelloader.load_spandrel_model(
+            filename,
+            device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+            expected_architecture='ESRGAN',
+        )


 def esrgan_upscale(model, img):
-    if opts.ESRGAN_tile == 0:
-        return upscale_without_tiling(model, img)
-
-    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
-    newtiles = []
-    scale_factor = 1
-
-    for y, h, row in grid.tiles:
-        newrow = []
-        for tiledata in row:
-            x, w, tile = tiledata
-
-            output = upscale_without_tiling(model, tile)
-            scale_factor = output.width // tile.width
-
-            newrow.append([x * scale_factor, w * scale_factor, output])
-        newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
-    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
-    output = images.combine_grid(newgrid)
-    return output
+    return upscale_with_model(
+        model,
+        img,
+        tile_size=opts.ESRGAN_tile,
+        tile_overlap=opts.ESRGAN_tile_overlap,
+    )
modules/face_restoration_utils.py (new file)
from __future__ import annotations

import logging
import os
from functools import cached_property
from typing import TYPE_CHECKING, Callable

import cv2
import numpy as np
import torch

from modules import devices, errors, face_restoration, shared

if TYPE_CHECKING:
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper

logger = logging.getLogger(__name__)


def create_face_helper(device) -> FaceRestoreHelper:
    from facexlib.detection import retinaface
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
    if hasattr(retinaface, 'device'):
        retinaface.device = device
    return FaceRestoreHelper(
        upscale_factor=1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        use_parse=True,
        device=device,
    )


def restore_with_face_helper(
    np_image: np.ndarray,
    face_helper: FaceRestoreHelper,
    restore_face: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """
    Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.

    `restore_face` should take a cropped face image and return a restored face image.
    """
    from basicsr.utils import img2tensor, tensor2img
    from torchvision.transforms.functional import normalize
    np_image = np_image[:, :, ::-1]
    original_resolution = np_image.shape[0:2]

    try:
        logger.debug("Detecting faces...")
        face_helper.clean_all()
        face_helper.read_image(np_image)
        face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
        face_helper.align_warp_face()
        logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
        for cropped_face in face_helper.cropped_faces:
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

            try:
                with torch.no_grad():
                    restored_face = tensor2img(
                        restore_face(cropped_face_t),
                        rgb2bgr=True,
                        min_max=(-1, 1),
                    )
                devices.torch_gc()
            except Exception:
                errors.report('Failed face-restoration inference', exc_info=True)
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

            restored_face = restored_face.astype('uint8')
            face_helper.add_restored_face(restored_face)

        logger.debug("Merging restored faces into image")
        face_helper.get_inverse_affine(None)
        img = face_helper.paste_faces_to_input_image()
        img = img[:, :, ::-1]
        if original_resolution != img.shape[0:2]:
            img = cv2.resize(
                img,
                (0, 0),
                fx=original_resolution[1] / img.shape[1],
                fy=original_resolution[0] / img.shape[0],
                interpolation=cv2.INTER_LINEAR,
            )
        logger.debug("Face restoration complete")
    finally:
        face_helper.clean_all()
    return img


class CommonFaceRestoration(face_restoration.FaceRestoration):
    net: torch.nn.Module | None
    model_url: str
    model_download_name: str

    def __init__(self, model_path: str):
        super().__init__()
        self.net = None
        self.model_path = model_path
        os.makedirs(model_path, exist_ok=True)

    @cached_property
    def face_helper(self) -> FaceRestoreHelper:
        return create_face_helper(self.get_device())

    def send_model_to(self, device):
        if self.net:
            logger.debug("Sending %s to %s", self.net, device)
            self.net.to(device)
        if self.face_helper:
            logger.debug("Sending face helper to %s", device)
            self.face_helper.face_det.to(device)
            self.face_helper.face_parse.to(device)

    def get_device(self):
        raise NotImplementedError("get_device must be implemented by subclasses")

    def load_net(self) -> torch.nn.Module:
        raise NotImplementedError("load_net must be implemented by subclasses")

    def restore_with_helper(
        self,
        np_image: np.ndarray,
        restore_face: Callable[[np.ndarray], np.ndarray],
    ) -> np.ndarray:
        try:
            if self.net is None:
                self.net = self.load_net()
        except Exception:
            logger.warning("Unable to load face-restoration model", exc_info=True)
            return np_image

        try:
            self.send_model_to(self.get_device())
            return restore_with_face_helper(np_image, self.face_helper, restore_face)
        finally:
            if shared.opts.face_restoration_unload:
                self.send_model_to(devices.cpu)


def patch_facexlib(dirname: str) -> None:
    import facexlib.detection
    import facexlib.parsing

    det_facex_load_file_from_url = facexlib.detection.load_file_from_url
    par_facex_load_file_from_url = facexlib.parsing.load_file_from_url

    def update_kwargs(kwargs):
        return dict(kwargs, save_dir=dirname, model_dir=None)

    def facex_load_file_from_url(**kwargs):
        return det_facex_load_file_from_url(**update_kwargs(kwargs))

    def facex_load_file_from_url2(**kwargs):
        return par_facex_load_file_from_url(**update_kwargs(kwargs))

    facexlib.detection.load_file_from_url = facex_load_file_from_url
    facexlib.parsing.load_file_from_url = facex_load_file_from_url2
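patch_facexlib works by wrapping facexlib's module-level load_file_from_url helpers so detector/parser weights land in the webui's model directory instead of facexlib's default cache. The idea in isolation (a sketch with a hypothetical directory, not webui API):

    import facexlib.detection

    orig = facexlib.detection.load_file_from_url

    def patched(**kwargs):
        # force downloads into our own directory; model_dir=None avoids a nested subfolder
        return orig(**dict(kwargs, save_dir="models/facexlib", model_dir=None))

    facexlib.detection.load_file_from_url = patched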
modules/gfpgan_model.py
-import os
-
-import facexlib
-import gfpgan
-
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
-
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+from __future__ import annotations
+
+import logging
+import os
+
+import torch
+
+from modules import (
+    devices,
+    errors,
+    face_restoration,
+    face_restoration_utils,
+    modelloader,
+    shared,
+)
+
+logger = logging.getLogger(__name__)
+
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None


-def gfpgann():
-    global loaded_gfpgan_model
-    global model_path
-    global model_file_path
-    if loaded_gfpgan_model is not None:
-        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
-        return loaded_gfpgan_model
-
-    if gfpgan_constructor is None:
-        return None
-
-    models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
-    if len(models) == 1 and models[0].startswith("http"):
-        model_file = models[0]
-    elif len(models) != 0:
-        gfp_models = []
-        for item in models:
-            if 'GFPGAN' in os.path.basename(item):
-                gfp_models.append(item)
-        latest_file = max(gfp_models, key=os.path.getctime)
-        model_file = latest_file
-    else:
-        print("Unable to load gfpgan model!")
-        return None
-
-    if hasattr(facexlib.detection.retinaface, 'device'):
-        facexlib.detection.retinaface.device = devices.device_gfpgan
-    model_file_path = model_file
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
-    loaded_gfpgan_model = model
-
-    return model
-
-
-def send_model_to(model, device):
-    model.gfpgan.to(device)
-    model.face_helper.face_det.to(device)
-    model.face_helper.face_parse.to(device)
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+    def name(self):
+        return "GFPGAN"
+
+    def get_device(self):
+        return devices.device_gfpgan
+
+    def load_net(self) -> torch.nn.Module:
+        for model_path in modelloader.load_models(
+            model_path=self.model_path,
+            model_url=model_url,
+            command_path=self.model_path,
+            download_name=model_download_name,
+            ext_filter=['.pth'],
+        ):
+            if 'GFPGAN' in os.path.basename(model_path):
+                net = modelloader.load_spandrel_model(
+                    model_path,
+                    device=self.get_device(),
+                    expected_architecture='GFPGAN',
+                ).model
+                net.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
+                return net
+        raise ValueError("No GFPGAN model found")
+
+    def restore(self, np_image):
+        def restore_face(cropped_face_t):
+            assert self.net is not None
+            return self.net(cropped_face_t, return_rgb=False)[0]
+
+        return self.restore_with_helper(np_image, restore_face)


 def gfpgan_fix_faces(np_image):
-    model = gfpgann()
-    if model is None:
-        return np_image
-
-    send_model_to(model, devices.device_gfpgan)
-
-    np_image_bgr = np_image[:, :, ::-1]
-    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
-    np_image = gfpgan_output_bgr[:, :, ::-1]
-
-    model.face_helper.clean_all()
-
-    if shared.opts.face_restoration_unload:
-        send_model_to(model, devices.cpu)
-
+    if gfpgan_face_restorer:
+        return gfpgan_face_restorer.restore(np_image)
+    logger.warning("GFPGAN face restorer not set up")
     return np_image


-gfpgan_constructor = None
-
-
-def setup_model(dirname):
+def setup_model(dirname: str) -> None:
+    global gfpgan_face_restorer
+
     try:
-        os.makedirs(model_path, exist_ok=True)
-        from gfpgan import GFPGANer
-        from facexlib import detection, parsing  # noqa: F401
-        global user_path
-        global have_gfpgan
-        global gfpgan_constructor
-        global model_file_path
-
-        facexlib_path = model_path
-
-        if dirname is not None:
-            facexlib_path = dirname
-
-        load_file_from_url_orig = gfpgan.utils.load_file_from_url
-        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
-        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
-        def my_load_file_from_url(**kwargs):
-            return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
-        def facex_load_file_from_url(**kwargs):
-            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
-        def facex_load_file_from_url2(**kwargs):
-            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
-        gfpgan.utils.load_file_from_url = my_load_file_from_url
-        facexlib.detection.load_file_from_url = facex_load_file_from_url
-        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
-        user_path = dirname
-        have_gfpgan = True
-        gfpgan_constructor = GFPGANer
-
-        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
-            def name(self):
-                return "GFPGAN"
-
-            def restore(self, np_image):
-                return gfpgan_fix_faces(np_image)
-
-        shared.face_restorers.append(FaceRestorerGFPGAN())
+        face_restoration_utils.patch_facexlib(dirname)
+        gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+        shared.face_restorers.append(gfpgan_face_restorer)
     except Exception:
         errors.report("Error setting up GFPGAN", exc_info=True)
modules/hat_model.py (new file)
import os
import sys

from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model


class UpscalerHAT(Upscaler):
    def __init__(self, dirname):
        self.name = "HAT"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        for file in self.find_models(ext_filter=[".pt", ".pth"]):
            name = modelloader.friendly_name(file)
            scale = 4  # TODO: scale might not be 4, but we can't know without loading the model
            scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devices.device_esrgan)  # TODO: should probably be device_hat
        return upscale_with_model(
            model,
            img,
            tile_size=opts.ESRGAN_tile,  # TODO: should probably be HAT_tile
            tile_overlap=opts.ESRGAN_tile_overlap,  # TODO: should probably be HAT_tile_overlap
        )

    def load_model(self, path: str):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model file {path} not found")
        return modelloader.load_spandrel_model(
            path,
            device=devices.device_esrgan,  # TODO: should probably be device_hat
            expected_architecture='HAT',
        )
modules/images.py
@@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
     return grid

-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
+    @property
+    def tile_count(self) -> int:
+        """
+        The total number of tiles in the grid.
+        """
+        return sum(len(row[2]) for row in self.tiles)

-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
-    w = image.width
-    h = image.height
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
+    w, h = image.size

     non_overlap_width = tile_w - overlap
     non_overlap_height = tile_h - overlap
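The new tile_count property lets callers size a progress bar before iterating. A quick worked example, assuming split_grid's ceil-based row/column count: a 1024x768 image with 512px tiles and 64px overlap tiles as ceil(960/448) = 3 columns by ceil(704/448) = 2 rows.

    from PIL import Image
    from modules import images

    grid = images.split_grid(Image.new("RGB", (1024, 768)), tile_w=512, tile_h=512, overlap=64)
    print(grid.tile_count)  # 2 rows x 3 columns -> 6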
modules/launch_utils.py
@@ -345,13 +345,11 @@ def prepare_environment():
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
     stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
-    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

     try:
@@ -408,15 +406,10 @@ def prepare_environment():
         git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
         git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
         git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
-        git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
         git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

         startup_timer.record("clone repositores")

-        if not is_installed("lpips"):
-            run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
-            startup_timer.record("install CodeFormer requirements")
-
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)
modules/modelloader.py
 from __future__ import annotations

+import logging
 import os
 import shutil
 import importlib
 from urllib.parse import urlparse

+import torch
+
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
 from modules.paths import script_path, models_path

+logger = logging.getLogger(__name__)
+

 def load_file_from_url(
     url: str,
     *,
@@ -177,3 +183,24 @@ def load_upscalers():
         # Special case for UpscalerNone keeps it at the beginning of the list.
         key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
     )
+
+
+def load_spandrel_model(
+    path: str,
+    *,
+    device: str | torch.device | None,
+    half: bool = False,
+    dtype: str | None = None,
+    expected_architecture: str | None = None,
+):
+    import spandrel
+    model = spandrel.ModelLoader(device=device).load_from_file(path)
+    if expected_architecture and model.architecture != expected_architecture:
+        raise TypeError(f"Model {path} is not a {expected_architecture} model")
+    if half:
+        model = model.model.half()
+    if dtype:
+        model = model.model.to(dtype=dtype)
+    model.eval()
+    logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
+    return model
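A hedged usage sketch of the new helper (the model path is hypothetical). One thing to note from the code above: when half= or dtype= is given, the local variable is rebound to the bare torch module (model.model), so the returned object differs from the plain spandrel-descriptor case.

    from modules import modelloader, devices

    model = modelloader.load_spandrel_model(
        "models/ESRGAN/4x_example.pth",
        device=devices.device_esrgan,
        expected_architecture="ESRGAN",
    )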
modules/paths.py
@@ -38,7 +38,6 @@ mute_sdxl_imports()
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
     (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
-    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
 ]
modules/realesrgan_model.py
 import os

-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
 from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model


 class UpscalerRealESRGAN(Upscaler):
@@ -14,13 +11,9 @@ class UpscalerRealESRGAN(Upscaler):
         self.name = "RealESRGAN"
         self.user_path = path
         super().__init__()
-        try:
-            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
-            from realesrgan import RealESRGANer  # noqa: F401
-            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
         self.enable = True
         self.scalers = []
-        scalers = self.load_models(path)
+        scalers = get_realesrgan_models(self)

         local_model_paths = self.find_models(ext_filter=[".pth"])
         for scaler in scalers:
@@ -33,11 +26,6 @@ class UpscalerRealESRGAN(Upscaler):
             if scaler.name in opts.realesrgan_enabled_models:
                 self.scalers.append(scaler)
-        except Exception:
-            errors.report("Error importing Real-ESRGAN", exc_info=True)
-            self.enable = False
-            self.scalers = []

     def do_upscale(self, img, path):
         if not self.enable:
             return img
@@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
             errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
             return img

-        upsampler = RealESRGANer(
-            scale=info.scale,
-            model_path=info.local_data_path,
-            model=info.model(),
-            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
-            tile=opts.ESRGAN_tile,
-            tile_pad=opts.ESRGAN_tile_overlap,
-            device=self.device,
-        )
-
-        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
-        image = Image.fromarray(upsampled)
-        return image
+        mod = modelloader.load_spandrel_model(
+            info.local_data_path,
+            device=self.device,
+            half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+            expected_architecture="RealESRGAN",
+        )
+        return upscale_with_model(
+            mod,
+            img,
+            tile_size=opts.ESRGAN_tile,
+            tile_overlap=opts.ESRGAN_tile_overlap,
+            # TODO: `outscale`?
+        )

     def load_model(self, path):
         for scaler in self.scalers:
@@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
             return scaler
         raise ValueError(f"Unable to find model info: {path}")

-    def load_models(self, _):
-        return get_realesrgan_models(self)

-def get_realesrgan_models(scaler):
-    try:
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-        models = [
+def get_realesrgan_models(scaler: UpscalerRealESRGAN):
+    return [
         UpscalerData(
             name="R-ESRGAN General 4xV3",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN General WDN 4xV3",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN AnimeVideo",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN 4x+",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
         ),
         UpscalerData(
             name="R-ESRGAN 4x+ Anime6B",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
         ),
         UpscalerData(
             name="R-ESRGAN 2x+",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
             scale=2,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
         ),
     ]
-        return models
-    except Exception:
-        errors.report("Error making Real-ESRGAN models list", exc_info=True)
modules/sysinfo.py
@@ -26,11 +26,9 @@ environment_whitelist = {
     "OPENCLIP_PACKAGE",
     "STABLE_DIFFUSION_REPO",
     "K_DIFFUSION_REPO",
-    "CODEFORMER_REPO",
     "BLIP_REPO",
     "STABLE_DIFFUSION_COMMIT_HASH",
     "K_DIFFUSION_COMMIT_HASH",
-    "CODEFORMER_COMMIT_HASH",
     "BLIP_COMMIT_HASH",
     "COMMANDLINE_ARGS",
     "IGNORE_CMD_ARGS_ERRORS",
modules/upscaler.py
@@ -98,6 +98,9 @@ class UpscalerData:
         self.scale = scale
         self.model = model

+    def __repr__(self):
+        return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
+

 class UpscalerNone(Upscaler):
     name = "None"
modules/upscaler_utils.py (new file)
import logging
from typing import Callable

import numpy as np
import torch
import tqdm
from PIL import Image

from modules import devices, images

logger = logging.getLogger(__name__)


def upscale_without_tiling(model, img: Image.Image):
    img = np.array(img)
    img = img[:, :, ::-1]
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_esrgan)
    with torch.no_grad():
        output = model(img)
    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]
    return Image.fromarray(output, 'RGB')


def upscale_with_model(
    model: Callable[[torch.Tensor], torch.Tensor],
    img: Image.Image,
    *,
    tile_size: int,
    tile_overlap: int = 0,
    desc="tiled upscale",
) -> Image.Image:
    if tile_size <= 0:
        logger.debug("Upscaling %s without tiling", img)
        output = upscale_without_tiling(model, img)
        logger.debug("=> %s", output)
        return output

    grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
    newtiles = []

    with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
        for y, h, row in grid.tiles:
            newrow = []
            for x, w, tile in row:
                logger.debug("Tile (%d, %d) %s...", x, y, tile)
                output = upscale_without_tiling(model, tile)
                scale_factor = output.width // tile.width
                logger.debug("=> %s (scale factor %s)", output, scale_factor)
                newrow.append([x * scale_factor, w * scale_factor, output])
                p.update(1)
            newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = images.Grid(
        newtiles,
        tile_w=grid.tile_w * scale_factor,
        tile_h=grid.tile_h * scale_factor,
        image_w=grid.image_w * scale_factor,
        image_h=grid.image_h * scale_factor,
        overlap=grid.overlap * scale_factor,
    )
    return images.combine_grid(newgrid)
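Tying the pieces together, upscaling an image with any spandrel-supported model might look like this (a sketch; the file names are hypothetical, and tile_size=0 would disable tiling):

    from PIL import Image
    from modules import modelloader
    from modules.upscaler_utils import upscale_with_model

    model = modelloader.load_spandrel_model(
        "models/ESRGAN/4x_example.pth",
        device="cpu",
        expected_architecture="ESRGAN",
    )
    result = upscale_with_model(model, Image.open("input.png"), tile_size=192, tile_overlap=8)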
requirements_versions.txt
@@ -5,8 +5,8 @@ basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
 einops==0.4.1
+facexlib==0.3.0
 fastapi==0.94.0
-gfpgan==1.3.8
 gradio==3.41.2
 httpcore==0.15
 inflection==0.5.1
@@ -19,11 +19,10 @@ open-clip-torch==2.20.0
 piexif==1.1.3
 psutil==5.9.5
 pytorch_lightning==1.9.4
-realesrgan==0.3.0
 resize-right==0.0.2
 safetensors==0.3.1
 scikit-image==0.21.0
-timm==0.9.2
+spandrel==0.1.6
 tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
test/conftest.py
+import base64
 import os

 import pytest
-import base64

 test_files_path = os.path.dirname(__file__) + "/test_files"
+test_outputs_path = os.path.dirname(__file__) + "/test_outputs"
+
+
+def pytest_configure(config):
+    # We don't want to fail on Py.test command line arguments being
+    # parsed by webui:
+    os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")


 def file_to_base64(filename):
@@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str:
 @pytest.fixture(scope="session")  # session so we don't read this over and over
 def mask_basic_image_base64() -> str:
     return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
+
+
+@pytest.fixture(scope="session")
+def initialize() -> None:
+    import webui  # noqa: F401
test/test_face_restorers.py (new file)
import os
from test.conftest import test_files_path, test_outputs_path

import numpy as np
import pytest
from PIL import Image


@pytest.mark.usefixtures("initialize")
@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
def test_face_restorers(restorer_name):
    from modules import shared

    if restorer_name == "gfpgan":
        from modules import gfpgan_model
        gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
        restorer = gfpgan_model.gfpgan_fix_faces
    elif restorer_name == "codeformer":
        from modules import codeformer_model
        codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
        restorer = codeformer_model.codeformer.restore
    else:
        raise NotImplementedError("...")
    img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
    np_img = np.array(img, dtype=np.uint8)
    fixed_image = restorer(np_img)
    assert fixed_image.shape == np_img.shape
    assert not np.allclose(fixed_image, np_img)  # should have visibly changed
    Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
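This test exercises both restorers end to end: the initialize fixture imports webui (so it needs a working environment), the CI model cache step above keeps the weight downloads warm, and outputs go to the gitignored test/test_outputs directory. A sketch of a direct invocation:

    import pytest

    pytest.main(["-v", "test/test_face_restorers.py"])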