Commit 51f1cca8 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #14484 from akx/swinir-resample-for-div8

Refactor Torch-space upscale fully out of ScuNET/SwinIR
parents 980970d3 cf14a6a7
import sys import sys
import PIL.Image import PIL.Image
import numpy as np
import torch
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
from modules.shared import opts
from modules.upscaler_utils import tiled_upscale_2
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
...@@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img: PIL.Image.Image, selected_file): def do_upscale(self, img: PIL.Image.Image, selected_file):
devices.torch_gc() devices.torch_gc()
try: try:
model = self.load_model(selected_file) model = self.load_model(selected_file)
except Exception as e: except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img return img
device = devices.get_device_for('scunet') img = upscaler_utils.upscale_2(
tile = opts.SCUNET_tile img,
h, w = img.height, img.width model,
np_img = np.array(img) tile_size=shared.opts.SCUNET_tile,
np_img = np_img[:, :, ::-1] # RGB to BGR tile_overlap=shared.opts.SCUNET_tile_overlap,
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW scale=1, # ScuNET is a denoising model, not an upscaler
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore desc='ScuNET',
)
if tile > h or tile > w:
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img
with torch.no_grad():
torch_output = tiled_upscale_2(
torch_img,
model,
tile_size=opts.SCUNET_tile,
tile_overlap=opts.SCUNET_tile_overlap,
scale=1,
device=devices.get_device_for('scunet'),
desc="ScuNET tiles",
).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
devices.torch_gc() devices.torch_gc()
return img
output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB
return PIL.Image.fromarray((output * 255).astype(np.uint8))
def load_model(self, path: str): def load_model(self, path: str):
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
...@@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def on_ui_settings(): def on_ui_settings():
import gradio as gr import gradio as gr
from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
......
import logging import logging
import sys import sys
import numpy as np
import torch
from PIL import Image from PIL import Image
from modules import modelloader, devices, script_callbacks, shared from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
...@@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
current_config = (model_file, opts.SWIN_tile) current_config = (model_file, shared.opts.SWIN_tile)
device = self._get_device()
if self._cached_model_config == current_config: if self._cached_model_config == current_config:
model = self._cached_model model = self._cached_model
...@@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler): ...@@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler):
self._cached_model = model self._cached_model = model
self._cached_model_config = current_config self._cached_model_config = current_config
img = upscale( img = upscaler_utils.upscale_2(
img, img,
model, model,
tile=opts.SWIN_tile, tile_size=shared.opts.SWIN_tile,
tile_overlap=opts.SWIN_tile_overlap, tile_overlap=shared.opts.SWIN_tile_overlap,
device=device, scale=4, # TODO: This was hard-coded before too...
desc="SwinIR",
) )
devices.torch_gc() devices.torch_gc()
return img return img
...@@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler):
dtype=devices.dtype, dtype=devices.dtype,
expected_architecture="SwinIR", expected_architecture="SwinIR",
) )
if getattr(opts, 'SWIN_torch_compile', False): if getattr(shared.opts, 'SWIN_torch_compile', False):
try: try:
model_descriptor.model.compile() model_descriptor.model.compile()
except Exception: except Exception:
...@@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler): ...@@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler):
return devices.get_device_for('swinir') return devices.get_device_for('swinir')
def upscale(
img,
model,
*,
tile: int,
tile_overlap: int,
window_size=8,
scale=4,
device,
):
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device, dtype=devices.dtype)
with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = tiled_upscale_2(
img,
model,
tile_size=tile,
tile_overlap=tile_overlap,
scale=scale,
device=device,
desc="SwinIR tiles",
)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, "RGB")
def on_ui_settings(): def on_ui_settings():
import gradio as gr import gradio as gr
......
...@@ -11,23 +11,40 @@ from modules import images, shared, torch_utils ...@@ -11,23 +11,40 @@ from modules import images, shared, torch_utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def upscale_without_tiling(model, img: Image.Image): def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
img = np.array(img) img = np.array(img.convert("RGB"))
img = img[:, :, ::-1] img = img[:, :, ::-1] # flip RGB to BGR
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.transpose(img, (2, 0, 1)) # HWC to CHW
img = torch.from_numpy(img).float() img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
return torch.from_numpy(img)
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
if tensor.ndim == 4:
# If we're given a tensor with a batch dimension, squeeze it out
# (but only if it's a batch of size 1).
if tensor.shape[0] != 1:
raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
tensor = tensor.squeeze(0)
assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
# TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
arr = arr.astype(np.uint8)
arr = arr[:, :, ::-1] # flip BGR to RGB
return Image.fromarray(arr, "RGB")
def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
"""
Upscale a given PIL image using the given model.
"""
param = torch_utils.get_param(model) param = torch_utils.get_param(model)
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad(): with torch.no_grad():
output = model(img) tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
tensor = tensor.to(device=param.device, dtype=param.dtype)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() return torch_bgr_to_pil_image(model(tensor))
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
return Image.fromarray(output, 'RGB')
def upscale_with_model( def upscale_with_model(
...@@ -40,7 +57,7 @@ def upscale_with_model( ...@@ -40,7 +57,7 @@ def upscale_with_model(
) -> Image.Image: ) -> Image.Image:
if tile_size <= 0: if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img) logger.debug("Upscaling %s without tiling", img)
output = upscale_without_tiling(model, img) output = upscale_pil_patch(model, img)
logger.debug("=> %s", output) logger.debug("=> %s", output)
return output return output
...@@ -52,7 +69,7 @@ def upscale_with_model( ...@@ -52,7 +69,7 @@ def upscale_with_model(
newrow = [] newrow = []
for x, w, tile in row: for x, w, tile in row:
logger.debug("Tile (%d, %d) %s...", x, y, tile) logger.debug("Tile (%d, %d) %s...", x, y, tile)
output = upscale_without_tiling(model, tile) output = upscale_pil_patch(model, tile)
scale_factor = output.width // tile.width scale_factor = output.width // tile.width
logger.debug("=> %s (scale factor %s)", output, scale_factor) logger.debug("=> %s (scale factor %s)", output, scale_factor)
newrow.append([x * scale_factor, w * scale_factor, output]) newrow.append([x * scale_factor, w * scale_factor, output])
...@@ -71,19 +88,22 @@ def upscale_with_model( ...@@ -71,19 +88,22 @@ def upscale_with_model(
def tiled_upscale_2( def tiled_upscale_2(
img, img: torch.Tensor,
model, model,
*, *,
tile_size: int, tile_size: int,
tile_overlap: int, tile_overlap: int,
scale: int, scale: int,
device,
desc="Tiled upscale", desc="Tiled upscale",
): ):
# Alternative implementation of `upscale_with_model` originally used by # Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting. # Pillow space without weighting.
# Grab the device the model is on, and use it.
device = torch_utils.get_param(model).device
b, c, h, w = img.size() b, c, h, w = img.size()
tile_size = min(tile_size, h, w) tile_size = min(tile_size, h, w)
...@@ -100,7 +120,8 @@ def tiled_upscale_2( ...@@ -100,7 +120,8 @@ def tiled_upscale_2(
h * scale, h * scale,
w * scale, w * scale,
device=device, device=device,
).type_as(img) dtype=img.dtype,
)
weights = torch.zeros_like(result) weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
...@@ -112,11 +133,13 @@ def tiled_upscale_2( ...@@ -112,11 +133,13 @@ def tiled_upscale_2(
if shared.state.interrupted or shared.state.skipped: if shared.state.interrupted or shared.state.skipped:
break break
# Only move this patch to the device if it's not already there.
in_patch = img[ in_patch = img[
..., ...,
h_idx : h_idx + tile_size, h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size, w_idx : w_idx + tile_size,
] ].to(device=device)
out_patch = model(in_patch) out_patch = model(in_patch)
result[ result[
...@@ -138,3 +161,29 @@ def tiled_upscale_2( ...@@ -138,3 +161,29 @@ def tiled_upscale_2(
output = result.div_(weights) output = result.div_(weights)
return output return output
def upscale_2(
img: Image.Image,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
desc: str,
):
"""
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
"""
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
with torch.no_grad():
output = tiled_upscale_2(
tensor,
model,
tile_size=tile_size,
tile_overlap=tile_overlap,
scale=scale,
desc=desc,
)
return torch_bgr_to_pil_image(output)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment