Commit e472383a authored by Aarni Koskela's avatar Aarni Koskela

Refactor esrgan_upscale to more generic upscale_with_model

parent 12c6f37f
import sys import sys
import numpy as np
import torch import torch
from PIL import Image
import modules.esrgan_model_arch as arch import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices from modules import modelloader, devices
from modules.shared import opts from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
def mod2normal(state_dict): def mod2normal(state_dict):
...@@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler): ...@@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler):
return model return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
return Image.fromarray(output, 'RGB')
def esrgan_upscale(model, img): def esrgan_upscale(model, img):
if opts.ESRGAN_tile == 0: return upscale_with_model(
return upscale_without_tiling(model, img) model,
img,
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) tile_size=opts.ESRGAN_tile,
newtiles = [] tile_overlap=opts.ESRGAN_tile_overlap,
scale_factor = 1 )
for y, h, row in grid.tiles:
newrow = []
for tiledata in row:
x, w, tile = tiledata
output = upscale_without_tiling(model, tile)
scale_factor = output.width // tile.width
newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = images.combine_grid(newgrid)
return output
import logging
from typing import Callable
import numpy as np
import torch
import tqdm
from PIL import Image
from modules import devices, images
logger = logging.getLogger(__name__)
def upscale_without_tiling(model, img: Image.Image):
img = np.array(img)
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
return Image.fromarray(output, 'RGB')
def upscale_with_model(
model: Callable[[torch.Tensor], torch.Tensor],
img: Image.Image,
*,
tile_size: int,
tile_overlap: int = 0,
desc="tiled upscale",
) -> Image.Image:
if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img)
output = upscale_without_tiling(model, img)
logger.debug("=> %s", output)
return output
grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
newtiles = []
with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
for y, h, row in grid.tiles:
newrow = []
for x, w, tile in row:
logger.debug("Tile (%d, %d) %s...", x, y, tile)
output = upscale_without_tiling(model, tile)
scale_factor = output.width // tile.width
logger.debug("=> %s (scale factor %s)", output, scale_factor)
newrow.append([x * scale_factor, w * scale_factor, output])
p.update(1)
newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = images.Grid(
newtiles,
tile_w=grid.tile_w * scale_factor,
tile_h=grid.tile_h * scale_factor,
image_w=grid.image_w * scale_factor,
image_h=grid.image_h * scale_factor,
overlap=grid.overlap * scale_factor,
)
return images.combine_grid(newgrid)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment