Commit 67a4549b authored by deggua's avatar deggua

Combined UI tabs into tab with a mode selection, UI renaming

parent c30aee2f
import argparse import argparse
import os import os
from pydoc import visiblename
import sys import sys
from collections import namedtuple from collections import namedtuple
import torch import torch
...@@ -53,20 +54,13 @@ parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="em ...@@ -53,20 +54,13 @@ parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="em
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
SamplerData = namedtuple('SamplerData', ['name', 'constructor']) SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [ samplers = [
*[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [ *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'), ('Euler', 'sample_euler'),
('Euler ancestral', 'sample_euler_ancestral'), ('Euler ancestral', 'sample_euler_ancestral'),
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('DPM 2', 'sample_dpm_2'), ('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'), ('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])], ] if hasattr(k_diffusion.sampling, x[1])],
...@@ -858,7 +852,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, ...@@ -858,7 +852,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
output_images, seed, info = process_images(p) output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info) return output_images, plaintext_to_html(info)
class Flagging(gr.FlaggingCallback): class Flagging(gr.FlaggingCallback):
...@@ -901,32 +895,6 @@ class Flagging(gr.FlaggingCallback): ...@@ -901,32 +895,6 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0]) print("Logged:", filenames[0])
txt2img_interface = gr.Interface(
wrap_gradio_call(txt2img),
inputs=[
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
],
outputs=[
gr.Gallery(label="Images"),
gr.Number(label='Seed'),
gr.HTML(),
],
title="Stable Diffusion Text-to-Image",
flagging_callback=Flagging()
)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None sampler = None
...@@ -1080,40 +1048,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG ...@@ -1080,40 +1048,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
else: else:
output_images, seed, info = process_images(p) output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info) return output_images, plaintext_to_html(info)
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
img2img_interface = gr.Interface(
wrap_gradio_call(img2img),
inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
gr.Checkbox(label='Stable Diffusion upscale', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
],
outputs=[
gr.Gallery(),
gr.Number(label='Seed'),
gr.HTML(),
],
allow_flagging="never",
)
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index): def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
info = realesrgan_models[RealESRGAN_model_index] info = realesrgan_models[RealESRGAN_model_index]
...@@ -1158,23 +1093,6 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in ...@@ -1158,23 +1093,6 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
return image, 0, '' return image, 0, ''
extras_interface = gr.Interface(
wrap_gradio_call(run_extras),
inputs=[
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None),
gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan),
gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan),
],
outputs=[
gr.Image(label="Result"),
gr.Number(label='Seed', visible=False),
gr.HTML(),
],
allow_flagging="never",
)
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
...@@ -1212,26 +1130,6 @@ def create_setting_component(key): ...@@ -1212,26 +1130,6 @@ def create_setting_component(key):
return item return item
settings_interface = gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never",
)
interfaces = [
(txt2img_interface, "txt2img"),
(img2img_interface, "img2img"),
(extras_interface, "Extras"),
(settings_interface, "Settings"),
]
sd_config = OmegaConf.load(cmd_opts.config) sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt) sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
...@@ -1241,13 +1139,298 @@ sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device) ...@@ -1241,13 +1139,298 @@ sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
model_hijack = StableDiffuionModelHijack() model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(sd_model) model_hijack.hijack(sd_model)
demo = gr.TabbedInterface( def do_generate(
interface_list=[x[0] for x in interfaces], mode: str,
tab_names=[x[1] for x in interfaces], prompt: str,
css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """ cfg: float,
denoise: float,
sampler_index: str,
sampler_steps: int,
batch_count: int,
batch_size: int,
input_img,
resize_mode,
image_height: int,
image_width: int,
use_input_seed: bool,
input_seed: int,
facefix: bool,
facefix_strength: float,
prompt_matrix: bool,
loopback: bool,
upscale: bool):
if mode == 'Text-to-Image':
return txt2img(
prompt=prompt,
ddim_steps=sampler_steps,
sampler_index=sampler_index,
use_GFPGAN=facefix,
prompt_matrix=prompt_matrix,
n_iter=batch_count,
batch_size=batch_size,
cfg_scale=cfg,
seed=input_seed if use_input_seed else -1,
height=image_height,
width=image_width
)
elif mode == 'Image-to-Image':
return img2img(
prompt=prompt,
init_img=input_img,
ddim_steps=sampler_steps,
sampler_index=sampler_index,
use_GFPGAN=facefix,
prompt_matrix=prompt_matrix,
loopback=loopback,
sd_upscale=upscale,
n_iter=batch_count,
batch_size=batch_size,
cfg_scale=cfg,
denoising_strength=denoise,
seed=input_seed if use_input_seed else -1,
height=image_height,
width=image_width,
resize_mode=resize_mode
)
elif mode == 'Post-Processing':
return run_extras(
image=input_img,
GFPGAN_strength=facefix_strength,
RealESRGAN_upscaling=1.0,
RealESRGAN_model_index=0
)
raise Exception('Invalid mode selected')
css_hide_progressbar = \
"""
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
main_css = \
"""
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; } .performance { font-size: 0.85em; color: #444; }
""" """
)
#[data-testid="image"] {min-height: 512px !important}
custom_css = \
"""
#output_gallery {
min-height: 50vh !important;
scrollbar-width: none;
}
::-webkit-scrollbar {
display: none;
}
* #body>.col:nth-child(2){width:250%;max-width:89vw}
#generate{width: 100%; }
#prompt_row input{
font-size:16px
}
"""
full_css = main_css + css_hide_progressbar + custom_css
with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion WebUI') as demo:
with gr.Tabs(elem_id='tabs'):
with gr.TabItem('Stable Diffusion', id='txt2img_tab'):
with gr.Row(elem_id='prompt_row'):
sd_prompt = gr.Textbox(elem_id='prompt_input', placeholder='A corgi wearing a top hat as an oil painting.', lines=1, max_lines=1, show_label=False)
with gr.Row(elem_id='body').style(equal_height=False):
# Left Column
with gr.Column():
sd_mode = \
gr.Dropdown(label='Mode', value='Text-to-Image', choices=['Text-to-Image', 'Image-to-Image', 'Post-Processing'])
with gr.Row():
sd_image_height = \
gr.Number(label="Image Height", value=512, precision=0)
sd_image_width = \
gr.Number(label="Image Width", value=512, precision=0)
with gr.Row():
sd_batch_count = \
gr.Number(label='Batch count', precision=0, value=1)
sd_batch_size = \
gr.Number(label='Images per batch', precision=0, value=1)
with gr.Group():
sd_input_image = \
gr.Image(label='Input Image', source="upload", interactive=True, type="pil", show_label=False, visible=False)
sd_resize_mode = \
gr.Dropdown(label="Resize mode", choices=["Stretch", "Proportional stretch & crop", "Proportional stretch & fill"], type="index", value="Stretch", visible=False)
sd_generate = \
gr.Button('Generate').style(full_width=True)
# Center Column
with gr.Column():
sd_output_image = \
gr.Gallery(show_label=False, elem_id='output_gallery').style(grid=3)
sd_output_html = \
gr.HTML()
# Right Column
with gr.Column():
with gr.Row():
sd_sampling_method = \
gr.Dropdown(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index")
sd_sampling_steps = \
gr.Slider(label="Sampling steps", value=30, minimum=5, maximum=100, step=5)
with gr.Group():
sd_cfg = \
gr.Slider(label='Prompt similarity (CFG)', value=8.0, minimum=1.0, maximum=15.0, step=0.5)
sd_denoise = \
gr.Slider(label='Denoising Strength (DNS)', value=0.75, minimum=0.0, maximum=1.0, step=0.01, visible=False)
sd_use_input_seed = \
gr.Checkbox(label='Custom seed')
sd_input_seed = \
gr.Number(value=-1, visible=False, show_label=False)
sd_facefix = \
gr.Checkbox(label='GFPGAN', value=False, visible=GFPGAN is not None)
sd_facefix_strength = \
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None, visible=False)
sd_matrix = \
gr.Checkbox(label='Create prompt matrix', value=False)
sd_loopback = \
gr.Checkbox(label='Output loopback', value=False, visible=False)
sd_upscale = \
gr.Checkbox(label='Super resolution upscale', value=False, visible=False)
with gr.TabItem('Settings', id='settings_tab'):
# TODO: fix this
gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never")
def mode_change(mode: str):
is_img2img = (mode == 'Image-to-Image')
is_txt2img = (mode == 'Text-to-Image')
is_pp = (mode == 'Post-Processing')
return {
sd_cfg: gr.update(visible=is_img2img or is_txt2img),
sd_denoise: gr.update(visible=is_img2img),
sd_sampling_method: gr.update(visible=is_img2img or is_txt2img),
sd_sampling_steps: gr.update(visible=is_img2img or is_txt2img),
sd_batch_count: gr.update(visible=is_img2img or is_txt2img),
sd_batch_size: gr.update(visible=is_img2img or is_txt2img),
sd_input_image: gr.update(visible=is_img2img or is_pp),
sd_resize_mode: gr.update(visible=is_img2img),
sd_image_height: gr.update(visible=is_img2img or is_txt2img),
sd_image_width: gr.update(visible=is_img2img or is_txt2img),
sd_use_input_seed: gr.update(visible=is_img2img or is_txt2img),
# TODO: can we handle this by updating use_input_seed and having its callback handle it?
sd_input_seed: gr.update(visible=False),
sd_facefix: gr.update(visible=True),
# TODO: see above, but for facefix
sd_facefix_strength: gr.update(visible=False),
sd_matrix: gr.update(visible=is_img2img or is_txt2img),
sd_loopback: gr.update(visible=is_img2img),
sd_upscale: gr.update(visible=is_img2img)
}
sd_mode.change(
fn=mode_change,
inputs=sd_mode,
outputs=[
sd_cfg,
sd_denoise,
sd_sampling_method,
sd_sampling_steps,
sd_batch_count,
sd_batch_size,
sd_input_image,
sd_resize_mode,
sd_image_height,
sd_image_width,
sd_use_input_seed,
sd_input_seed,
sd_facefix,
sd_facefix_strength,
sd_matrix,
sd_loopback,
sd_upscale
]
)
do_generate_args = dict(
fn=wrap_gradio_call(do_generate),
inputs=[
sd_mode,
sd_prompt,
sd_cfg,
sd_denoise,
sd_sampling_method,
sd_sampling_steps,
sd_batch_count,
sd_batch_size,
sd_input_image,
sd_resize_mode,
sd_image_height,
sd_image_width,
sd_use_input_seed,
sd_input_seed,
sd_facefix,
sd_facefix_strength,
sd_matrix,
sd_loopback,
sd_upscale
],
outputs=[
sd_output_image,
sd_output_html
]
)
sd_prompt.submit(**do_generate_args)
sd_generate.click(**do_generate_args)
sd_use_input_seed.change(
lambda checked : gr.update(visible=checked),
inputs=sd_use_input_seed,
outputs=sd_input_seed
)
sd_image_height.submit(
lambda value : 64 * ((value + 63) // 64) if value > 0 else 512,
inputs=sd_image_height,
outputs=sd_image_height
)
sd_image_width.submit(
lambda value : 64 * ((value + 63) // 64) if value > 0 else 512,
inputs=sd_image_width,
outputs=sd_image_width
)
sd_batch_count.submit(
lambda value : value if value > 0 else 1,
inputs=sd_batch_count,
outputs=sd_batch_count
)
sd_batch_size.submit(
lambda value : value if value > 0 else 1,
inputs=sd_batch_size,
outputs=sd_batch_size
)
demo.launch() demo.launch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment