Commit 1fba573d authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #3874 from cobryan05/extra_tweak

Extras Tab - Option to upscale before face fix, caching improvements
parents 2338ed95 d8b36614
from __future__ import annotations
import math import math
import os import os
...@@ -7,6 +8,10 @@ from PIL import Image ...@@ -7,6 +8,10 @@ from PIL import Image
import torch import torch
import tqdm import tqdm
from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models from modules import processing, shared, images, devices, sd_models
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
...@@ -17,10 +22,38 @@ import piexif.helper ...@@ -17,10 +22,38 @@ import piexif.helper
import gradio as gr import gradio as gr
cached_images = {} class LruCache(OrderedDict):
@dataclass(frozen=True)
class Key:
image_hash: int
info_hash: int
args_hash: int
@dataclass
class Value:
image: Image.Image
info: str
def __init__(self, max_size: int = 5, *args, **kwargs):
super().__init__(*args, **kwargs)
self._max_size = max_size
def get(self, key: LruCache.Key) -> LruCache.Value:
ret = super().get(key)
if ret is not None:
self.move_to_end(key) # Move to end of eviction list
return ret
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
self[key] = value
while len(self) > self._max_size:
self.popitem(last=False)
cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
devices.torch_gc() devices.torch_gc()
imageArr = [] imageArr = []
...@@ -56,72 +89,102 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -56,72 +89,102 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else: else:
outpath = opts.outdir_samples or opts.outdir_extras_samples outpath = opts.outdir_samples or opts.outdir_extras_samples
# Extra operation definitions
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB") def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
info = "" restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
if gfpgan_visibility > 0: if gfpgan_visibility < 1.0:
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) res = Image.blend(image, res, gfpgan_visibility)
res = Image.fromarray(restored_img)
if gfpgan_visibility < 1.0: info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
res = Image.blend(image, res, gfpgan_visibility) return (res, info)
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
image = res restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
if codeformer_visibility > 0: if codeformer_visibility < 1.0:
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) res = Image.blend(image, res, codeformer_visibility)
res = Image.fromarray(restored_img)
if codeformer_visibility < 1.0: info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
res = Image.blend(image, res, codeformer_visibility) return (res, info)
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
image = res upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
res = cropped
return res
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
nonlocal upscaling_resize
if resize_mode == 1: if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else "" crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
return (image, info)
@dataclass
class UpscaleParams:
upscaler_idx: int
blend_alpha: float
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info),
args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
else:
res, info = cached_entry.image, cached_entry.info
if blended_result is None:
blended_result = res
else:
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
return (blended_result, info)
# Build a list of operations to run
facefix_ops: List[Callable] = []
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
upscale_ops: List[Callable] = []
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
if upscaling_resize != 0:
step_params: List[UpscaleParams] = []
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
upscale_ops.append(partial(run_upscalers_blend, step_params))
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB")
info = ""
# Run each operation on each image
for op in extras_ops:
image, info = op(image, info)
if upscaling_resize != 1.0:
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
c = cached_images.get(key)
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c
return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility)
image = res
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
if opts.use_original_name_batch and image_name != None: if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0] basename = os.path.splitext(os.path.basename(image_name))[0]
else: else:
...@@ -141,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -141,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, plaintext_to_html(info), '' return outputs, plaintext_to_html(info), ''
def clear_cache():
cached_images.clear()
def run_pnginfo(image): def run_pnginfo(image):
if image is None: if image is None:
......
...@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call):
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
...@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_upscaler_1, extras_upscaler_1,
extras_upscaler_2, extras_upscaler_2,
extras_upscaler_2_visibility, extras_upscaler_2_visibility,
upscale_before_face_fix,
], ],
outputs=[ outputs=[
result_images, result_images,
...@@ -1174,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1174,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[init_img_with_mask], outputs=[init_img_with_mask],
) )
extras_image.change(
fn=modules.extras.clear_cache,
inputs=[], outputs=[]
)
with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment