Commit 3f7f61e5 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #14524 from akx/fix-swinir-issues

Fix SwinIR issues
parents 1e7a8ce5 62470ee2
import logging import logging
import sys import sys
import torch
from PIL import Image from PIL import Image
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
...@@ -50,7 +51,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -50,7 +51,7 @@ class UpscalerSwinIR(Upscaler):
model, model,
tile_size=shared.opts.SWIN_tile, tile_size=shared.opts.SWIN_tile,
tile_overlap=shared.opts.SWIN_tile_overlap, tile_overlap=shared.opts.SWIN_tile_overlap,
scale=4, # TODO: This was hard-coded before too... scale=model.scale,
desc="SwinIR", desc="SwinIR",
) )
devices.torch_gc() devices.torch_gc()
...@@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler):
model_descriptor = modelloader.load_spandrel_model( model_descriptor = modelloader.load_spandrel_model(
filename, filename,
device=self._get_device(), device=self._get_device(),
dtype=devices.dtype, prefer_half=(devices.dtype == torch.float16),
expected_architecture="SwinIR", expected_architecture="SwinIR",
) )
if getattr(shared.opts, 'SWIN_torch_compile', False): if getattr(shared.opts, 'SWIN_torch_compile', False):
......
...@@ -94,6 +94,7 @@ def tiled_upscale_2( ...@@ -94,6 +94,7 @@ def tiled_upscale_2(
tile_size: int, tile_size: int,
tile_overlap: int, tile_overlap: int,
scale: int, scale: int,
device: torch.device,
desc="Tiled upscale", desc="Tiled upscale",
): ):
# Alternative implementation of `upscale_with_model` originally used by # Alternative implementation of `upscale_with_model` originally used by
...@@ -101,9 +102,6 @@ def tiled_upscale_2( ...@@ -101,9 +102,6 @@ def tiled_upscale_2(
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting. # Pillow space without weighting.
# Grab the device the model is on, and use it.
device = torch_utils.get_param(model).device
b, c, h, w = img.size() b, c, h, w = img.size()
tile_size = min(tile_size, h, w) tile_size = min(tile_size, h, w)
...@@ -175,7 +173,8 @@ def upscale_2( ...@@ -175,7 +173,8 @@ def upscale_2(
""" """
Convenience wrapper around `tiled_upscale_2` that handles PIL images. Convenience wrapper around `tiled_upscale_2` that handles PIL images.
""" """
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension param = torch_utils.get_param(model)
tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
with torch.no_grad(): with torch.no_grad():
output = tiled_upscale_2( output = tiled_upscale_2(
...@@ -185,5 +184,6 @@ def upscale_2( ...@@ -185,5 +184,6 @@ def upscale_2(
tile_overlap=tile_overlap, tile_overlap=tile_overlap,
scale=scale, scale=scale,
desc=desc, desc=desc,
device=param.device,
) )
return torch_bgr_to_pil_image(output) return torch_bgr_to_pil_image(output)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment