Commit 19ba176d authored by deggua

Fixed RGB/BGR colorspace mismatch in GFPGAN implementation, added GFPGAN...

Fixed RGB/BGR colorspace mismatch in GFPGAN implementation, added GFPGAN strength control to img2img/txt2img modes
parent 021ba36a
import argparse import argparse
from doctest import OutputChecker
import os import os
from pydoc import visiblename from pydoc import visiblename
import sys import sys
...@@ -639,7 +640,7 @@ class EmbeddingsWithFixes(nn.Module): ...@@ -639,7 +640,7 @@ class EmbeddingsWithFixes(nn.Module):
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None): def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, strength_GFPGAN = 1.0, do_not_save_grid=False, extra_generation_params=None):
self.outpath: str = outpath self.outpath: str = outpath
self.prompt: str = prompt self.prompt: str = prompt
self.seed: int = seed self.seed: int = seed
...@@ -652,6 +653,7 @@ class StableDiffusionProcessing: ...@@ -652,6 +653,7 @@ class StableDiffusionProcessing:
self.height: int = height self.height: int = height
self.prompt_matrix: bool = prompt_matrix self.prompt_matrix: bool = prompt_matrix
self.use_GFPGAN: bool = use_GFPGAN self.use_GFPGAN: bool = use_GFPGAN
# GFPGAN blend weight: a float (constructor default 1.0), not a flag; process_images passes it to Image.blend as alpha.
self.strength_GFPGAN: float = strength_GFPGAN
self.do_not_save_grid: bool = do_not_save_grid self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params self.extra_generation_params: dict = extra_generation_params
...@@ -785,18 +787,24 @@ def process_images(p: StableDiffusionProcessing): ...@@ -785,18 +787,24 @@ def process_images(p: StableDiffusionProcessing):
if p.prompt_matrix or opts.samples_save or opts.grid_save: if p.prompt_matrix or opts.samples_save or opts.grid_save:
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
# TODO: convert to BGR colorspace?
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
x_sample_bgr = x_sample[:,:,::-1]
if p.use_GFPGAN and GFPGAN is not None: if p.use_GFPGAN and GFPGAN is not None and p.strength_GFPGAN > 0.0:
torch_gc() torch_gc()
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = GFPGAN.enhance(x_sample_bgr, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img gfpgan_output_rgb = gfpgan_output_bgr[:,:,::-1]
output_image = Image.fromarray(gfpgan_output_rgb)
image = Image.fromarray(x_sample) if p.strength_GFPGAN < 1.0:
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext()) output_image = Image.blend(Image.fromarray(x_sample), output_image, p.strength_GFPGAN)
else:
output_image = Image.fromarray(x_sample)
output_images.append(image) save_image(output_image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
output_images.append(output_image)
base_count += 1 base_count += 1
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid: if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
...@@ -832,7 +840,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -832,7 +840,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples_ddim return samples_ddim
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, strength_GFPGAN: float, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opts.outdir or "outputs/txt2img-samples" outpath = opts.outdir or "outputs/txt2img-samples"
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
...@@ -847,7 +855,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, ...@@ -847,7 +855,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
width=width, width=width,
height=height, height=height,
prompt_matrix=prompt_matrix, prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN use_GFPGAN=use_GFPGAN,
strength_GFPGAN=strength_GFPGAN
) )
output_images, seed, info = process_images(p) output_images, seed, info = process_images(p)
...@@ -943,7 +952,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -943,7 +952,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
return samples_ddim return samples_ddim
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, strength_GFPGAN: float, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples" outpath = opts.outdir or "outputs/img2img-samples"
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
...@@ -961,6 +970,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG ...@@ -961,6 +970,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
height=height, height=height,
prompt_matrix=prompt_matrix, prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN, use_GFPGAN=use_GFPGAN,
strength_GFPGAN=strength_GFPGAN,
init_images=[init_img], init_images=[init_img],
resize_mode=resize_mode, resize_mode=resize_mode,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
...@@ -1075,8 +1085,10 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in ...@@ -1075,8 +1085,10 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
outpath = opts.outdir or "outputs/extras-samples" outpath = opts.outdir or "outputs/extras-samples"
if GFPGAN is not None and GFPGAN_strength > 0: if GFPGAN is not None and GFPGAN_strength > 0:
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) img_data_bgr = np.array(image, dtype=np.uint8)[:,:,::-1]
res = Image.fromarray(restored_img) cropped_faces, restored_faces, restored_img = GFPGAN.enhance(img_data_bgr, has_aligned=False, only_center_face=False, paste_back=True)
img_data_rgb = restored_img[:,:,::-1]
res = Image.fromarray(img_data_rgb)
if GFPGAN_strength < 1.0: if GFPGAN_strength < 1.0:
res = Image.blend(image, res, GFPGAN_strength) res = Image.blend(image, res, GFPGAN_strength)
...@@ -1091,7 +1103,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in ...@@ -1091,7 +1103,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True) save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)
return image, 0, '' return [image], 0, ''
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
...@@ -1166,6 +1178,7 @@ def do_generate( ...@@ -1166,6 +1178,7 @@ def do_generate(
ddim_steps=sampler_steps, ddim_steps=sampler_steps,
sampler_index=sampler_index, sampler_index=sampler_index,
use_GFPGAN=facefix, use_GFPGAN=facefix,
strength_GFPGAN=facefix_strength,
prompt_matrix=prompt_matrix, prompt_matrix=prompt_matrix,
n_iter=batch_count, n_iter=batch_count,
batch_size=batch_size, batch_size=batch_size,
...@@ -1181,6 +1194,7 @@ def do_generate( ...@@ -1181,6 +1194,7 @@ def do_generate(
ddim_steps=sampler_steps, ddim_steps=sampler_steps,
sampler_index=sampler_index, sampler_index=sampler_index,
use_GFPGAN=facefix, use_GFPGAN=facefix,
strength_GFPGAN=facefix_strength,
prompt_matrix=prompt_matrix, prompt_matrix=prompt_matrix,
loopback=loopback, loopback=loopback,
sd_upscale=upscale, sd_upscale=upscale,
...@@ -1338,23 +1352,23 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1338,23 +1352,23 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
gr.Slider(label='Denoising Strength (DNS)', value=0.75, minimum=0.0, maximum=1.0, step=0.01, visible=False) gr.Slider(label='Denoising Strength (DNS)', value=0.75, minimum=0.0, maximum=1.0, step=0.01, visible=False)
sd_facefix = \ sd_facefix = \
gr.Checkbox(label='GFPGAN', value=False, visible=GFPGAN is not None).style(rounded=False) gr.Checkbox(label='GFPGAN', value=False, visible=GFPGAN is not None)
sd_facefix_strength = \ sd_facefix_strength = \
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None, visible=False) gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None, visible=False)
sd_use_input_seed = \ sd_use_input_seed = \
gr.Checkbox(label='Custom seed').style(rounded=False) gr.Checkbox(label='Custom seed')
sd_input_seed = \ sd_input_seed = \
gr.Number(value=-1, visible=False, show_label=False) gr.Number(value=-1, visible=False, show_label=False)
# TODO: Change to 'Enable syntactic prompts' # TODO: Change to 'Enable syntactic prompts'
sd_matrix = \ sd_matrix = \
gr.Checkbox(label='Create prompt matrix', value=False).style(rounded=False) gr.Checkbox(label='Create prompt matrix', value=False)
sd_loopback = \ sd_loopback = \
gr.Checkbox(label='Output loopback', value=False, visible=False).style(rounded=False) gr.Checkbox(label='Output loopback', value=False, visible=False)
sd_upscale = \ sd_upscale = \
gr.Checkbox(label='Super resolution upscale', value=False, visible=False).style(rounded=False) gr.Checkbox(label='Super resolution upscale', value=False, visible=False)
with gr.TabItem('Settings', id='settings_tab'): with gr.TabItem('Settings', id='settings_tab'):
# TODO: fix this # TODO: fix this
...@@ -1369,7 +1383,7 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1369,7 +1383,7 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
description=None, description=None,
allow_flagging="never") allow_flagging="never")
def mode_change(mode: str): def mode_change(mode: str, facefix: bool, custom_seed: bool):
is_img2img = (mode == 'Image-to-Image') is_img2img = (mode == 'Image-to-Image')
is_txt2img = (mode == 'Text-to-Image') is_txt2img = (mode == 'Text-to-Image')
is_pp = (mode == 'Post-Processing') is_pp = (mode == 'Post-Processing')
...@@ -1387,10 +1401,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1387,10 +1401,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
sd_image_width: gr.update(visible=is_img2img or is_txt2img), sd_image_width: gr.update(visible=is_img2img or is_txt2img),
sd_use_input_seed: gr.update(visible=is_img2img or is_txt2img), sd_use_input_seed: gr.update(visible=is_img2img or is_txt2img),
# TODO: can we handle this by updating use_input_seed and having its callback handle it? # TODO: can we handle this by updating use_input_seed and having its callback handle it?
sd_input_seed: gr.update(visible=False), sd_input_seed: gr.update(visible=(is_img2img or is_txt2img) and custom_seed),
sd_facefix: gr.update(visible=True), sd_facefix: gr.update(visible=True),
# TODO: see above, but for facefix # TODO: see above, but for facefix
sd_facefix_strength: gr.update(visible=False), sd_facefix_strength: gr.update(visible=facefix),
sd_matrix: gr.update(visible=is_img2img or is_txt2img), sd_matrix: gr.update(visible=is_img2img or is_txt2img),
sd_loopback: gr.update(visible=is_img2img), sd_loopback: gr.update(visible=is_img2img),
sd_upscale: gr.update(visible=is_img2img) sd_upscale: gr.update(visible=is_img2img)
...@@ -1398,7 +1412,11 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1398,7 +1412,11 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
sd_mode.change( sd_mode.change(
fn=mode_change, fn=mode_change,
inputs=sd_mode, inputs=[
sd_mode,
sd_facefix,
sd_use_input_seed
],
outputs=[ outputs=[
sd_cfg, sd_cfg,
sd_denoise, sd_denoise,
...@@ -1482,4 +1500,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1482,4 +1500,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
outputs=sd_batch_size outputs=sd_batch_size
) )
# Keep the "GFPGAN strength" slider's visibility in sync with the GFPGAN
# checkbox: the checkbox value is forwarded as the slider's `visible` flag
# every time the checkbox changes.
sd_facefix.change(
lambda checked : gr.update(visible=checked),
inputs=sd_facefix,
outputs=sd_facefix_strength
)
demo.launch() demo.launch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment