Commit 67a4549b authored by deggua's avatar deggua

Combined the UI tabs into a single tab with a mode selector; renamed UI elements

parent c30aee2f
import argparse
import os
from pydoc import visiblename
import sys
from collections import namedtuple
import torch
......@@ -53,20 +54,13 @@ parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="em
cmd_opts = parser.parse_args()
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'),
('Euler ancestral', 'sample_euler_ancestral'),
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
......@@ -858,7 +852,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info)
return output_images, plaintext_to_html(info)
class Flagging(gr.FlaggingCallback):
......@@ -901,32 +895,6 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0])
# Stand-alone Gradio interface for text-to-image generation.
# The input components are listed in the same order as txt2img's parameters
# (prompt, sampling steps, sampler index, GFPGAN flag, ...).
# NOTE(review): three output components are declared here (gallery, seed, HTML);
# that matches the three-value `return output_images, seed, plaintext_to_html(info)`
# form of txt2img, not the two-value form - confirm which one is current.
txt2img_interface = gr.Interface(
wrap_gradio_call(txt2img),
inputs=[
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
# type="index": the callback receives the position of the chosen sampler, not its name.
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
# Hidden when GFPGAN is unavailable (GFPGAN is None).
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1),
# Height and width sliders move in steps of 64 between 64 and 2048.
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
],
outputs=[
gr.Gallery(label="Images"),
gr.Number(label='Seed'),
gr.HTML(),
],
title="Stable Diffusion Text-to-Image",
flagging_callback=Flagging()
)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
......@@ -1080,40 +1048,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
else:
output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info)
# Default input picture for the img2img tab: use the bundled sample sketch
# only if it exists on disk, otherwise start with no image (None).
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
# Stand-alone Gradio interface for image-to-image generation; flagging is disabled.
# The sampler list comes from samplers_for_img2img, not from samplers.
img2img_interface = gr.Interface(
wrap_gradio_call(img2img),
inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
# type="pil": the uploaded picture is handed to the callback as a PIL image.
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
# type="index": the callback receives the position of the chosen sampler, not its name.
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
gr.Checkbox(label='Stable Diffusion upscale', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
# type="index": the resize mode reaches the callback as 0, 1 or 2.
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
],
outputs=[
gr.Gallery(),
gr.Number(label='Seed'),
gr.HTML(),
],
allow_flagging="never",
)
return output_images, plaintext_to_html(info)
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
info = realesrgan_models[RealESRGAN_model_index]
......@@ -1158,23 +1093,6 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
return image, 0, ''
# Stand-alone Gradio interface for post-processing a single uploaded image via run_extras.
# Controls are rendered non-interactive when the matching backend is unavailable
# (GFPGAN is None / have_realesrgan is falsy).
extras_interface = gr.Interface(
wrap_gradio_call(run_extras),
inputs=[
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None),
gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan),
gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan),
],
outputs=[
gr.Image(label="Result"),
# Hidden placeholder: run_extras returns a three-item tuple (image, 0, ''),
# so a component is still needed to receive the middle value.
gr.Number(label='Seed', visible=False),
gr.HTML(),
],
allow_flagging="never",
)
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
......@@ -1212,26 +1130,6 @@ def create_setting_component(key):
return item
# Settings form: one input component per key of opts.data_labels, each built
# by create_setting_component; run_settings receives the submitted values.
settings_interface = gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never",
)
# (interface, tab label) pairs; the tabbed UI takes element [0] of each pair
# as the interface and element [1] as the tab name.
interfaces = [
(txt2img_interface, "txt2img"),
(img2img_interface, "img2img"),
(extras_interface, "Extras"),
(settings_interface, "Settings"),
]
sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
......@@ -1241,13 +1139,298 @@ sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(sd_model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],
tab_names=[x[1] for x in interfaces],
css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """
def do_generate(
        mode: str,
        prompt: str,
        cfg: float,
        denoise: float,
        sampler_index: int,
        sampler_steps: int,
        batch_count: int,
        batch_size: int,
        input_img,
        resize_mode,
        image_height: int,
        image_width: int,
        use_input_seed: bool,
        input_seed: int,
        facefix: bool,
        facefix_strength: float,
        prompt_matrix: bool,
        loopback: bool,
        upscale: bool):
    """Dispatch one 'Generate' request to the backend selected by ``mode``.

    Every widget value of the combined tab is passed in; each backend only
    receives the subset it understands.

    Args:
        mode: 'Text-to-Image', 'Image-to-Image' or 'Post-Processing'.
        prompt: text prompt (ignored for post-processing).
        cfg: classifier-free guidance scale.
        denoise: denoising strength (img2img only).
        sampler_index: position of the chosen sampler in the sampler dropdown
            (the dropdown is created with ``type="index"``).
        sampler_steps: number of sampling steps.
        batch_count: number of batches, forwarded as ``n_iter``.
        batch_size: images per batch.
        input_img: uploaded image (img2img and post-processing).
        resize_mode: index of the selected resize mode (img2img only).
        image_height: requested output height.
        image_width: requested output width.
        use_input_seed: when false, ``input_seed`` is ignored and -1 is sent.
        input_seed: user-supplied seed.
        facefix: GFPGAN on/off flag for the two generation modes.
        facefix_strength: GFPGAN strength (post-processing only).
        prompt_matrix: enable the prompt-matrix feature.
        loopback: img2img loopback flag.
        upscale: img2img "SD upscale" flag.

    Returns:
        A two-item result ``(images, html)`` for the gallery and HTML outputs.

    Raises:
        ValueError: if ``mode`` is not one of the three known modes.
    """
    # -1 is what both backends are given when no custom seed is requested.
    seed = input_seed if use_input_seed else -1

    if mode == 'Text-to-Image':
        return txt2img(
            prompt=prompt,
            ddim_steps=sampler_steps,
            sampler_index=sampler_index,
            use_GFPGAN=facefix,
            prompt_matrix=prompt_matrix,
            n_iter=batch_count,
            batch_size=batch_size,
            cfg_scale=cfg,
            seed=seed,
            height=image_height,
            width=image_width
        )

    if mode == 'Image-to-Image':
        # NOTE(review): the dropdown lists `samplers`, while the former img2img
        # interface listed `samplers_for_img2img` - confirm img2img resolves
        # this index against the same list.
        return img2img(
            prompt=prompt,
            init_img=input_img,
            ddim_steps=sampler_steps,
            sampler_index=sampler_index,
            use_GFPGAN=facefix,
            prompt_matrix=prompt_matrix,
            loopback=loopback,
            sd_upscale=upscale,
            n_iter=batch_count,
            batch_size=batch_size,
            cfg_scale=cfg,
            denoising_strength=denoise,
            seed=seed,
            height=image_height,
            width=image_width,
            resize_mode=resize_mode
        )

    if mode == 'Post-Processing':
        # Real-ESRGAN is not exposed in this UI: scale 1.0 and model 0 are fixed.
        result = run_extras(
            image=input_img,
            GFPGAN_strength=facefix_strength,
            RealESRGAN_upscaling=1.0,
            RealESRGAN_model_index=0
        )
        # run_extras yields (image, 0, '') - a single image plus a seed
        # placeholder - but this handler feeds a gallery and an HTML box.
        # Keep the first and last items and wrap the image in a list so the
        # shape matches what the two generation modes return.
        return [result[0]], result[-1]

    raise ValueError('Invalid mode selected')
# Stylesheet fragment that hides the loading SVG, the .progress-bar and the
# .meta-text elements, and puts a plain "Loading..." label in their place.
css_hide_progressbar = "\n".join([
    "",
    ".wrap .m-12 svg { display:none!important; }",
    '.wrap .m-12::before { content:"Loading..." }',
    ".progress-bar { display:none!important; }",
    ".meta-text { display:none!important; }",
    "",
])
# Base stylesheet fragment: paragraph margins inside .output-html and the
# small grey text style of .performance.
main_css = "\n".join([
    "",
    ".output-html p {margin: 0 0.5em;}",
    ".performance { font-size: 0.85em; color: #444; }",
    "",
])
)
#[data-testid="image"] {min-height: 512px !important}
# Layout overrides for the combined tab. The selectors refer to elem_id values
# assigned in the Blocks layout: #output_gallery (result gallery), #body (the
# three-column row, whose second column is widened) and #prompt_row.
# The ::-webkit-scrollbar rule has no parent selector, so it is not limited to
# the gallery.
# NOTE(review): #generate is styled here, but the Generate button in the visible
# layout is created without elem_id='generate' - confirm the rule has a target.
custom_css = \
"""
#output_gallery {
min-height: 50vh !important;
scrollbar-width: none;
}
::-webkit-scrollbar {
display: none;
}
* #body>.col:nth-child(2){width:250%;max-width:89vw}
#generate{width: 100%; }
#prompt_row input{
font-size:16px
}
"""
# Final stylesheet handed to gr.Blocks. The progress-bar-hiding rules are left
# out when cmd_opts.no_progressbar_hiding is set, as the UI this layout
# replaces did; concatenating them unconditionally silently ignored the option.
full_css = main_css + ("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + custom_css
# Combined UI: a 'Stable Diffusion' tab whose widgets are shown or hidden
# according to the Mode dropdown, plus a 'Settings' tab.
# Widgets created with visible=False start hidden; they may be revealed later
# by the event handlers registered further down (see mode_change).
with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion WebUI') as demo:
with gr.Tabs(elem_id='tabs'):
with gr.TabItem('Stable Diffusion', id='txt2img_tab'):
# Prompt box; elem_id='prompt_row' is targeted by the custom CSS.
with gr.Row(elem_id='prompt_row'):
sd_prompt = gr.Textbox(elem_id='prompt_input', placeholder='A corgi wearing a top hat as an oil painting.', lines=1, max_lines=1, show_label=False)
with gr.Row(elem_id='body').style(equal_height=False):
# Left Column
with gr.Column():
# The three choices here are the exact strings do_generate and mode_change compare against.
sd_mode = \
gr.Dropdown(label='Mode', value='Text-to-Image', choices=['Text-to-Image', 'Image-to-Image', 'Post-Processing'])
with gr.Row():
# precision=0 makes these Number fields yield whole numbers.
sd_image_height = \
gr.Number(label="Image Height", value=512, precision=0)
sd_image_width = \
gr.Number(label="Image Width", value=512, precision=0)
with gr.Row():
sd_batch_count = \
gr.Number(label='Batch count', precision=0, value=1)
sd_batch_size = \
gr.Number(label='Images per batch', precision=0, value=1)
with gr.Group():
# Hidden in the default mode; mode_change shows the image for img2img and
# post-processing, and the resize mode for img2img only.
sd_input_image = \
gr.Image(label='Input Image', source="upload", interactive=True, type="pil", show_label=False, visible=False)
# type="index": the handler receives 0, 1 or 2 rather than the label text.
sd_resize_mode = \
gr.Dropdown(label="Resize mode", choices=["Stretch", "Proportional stretch & crop", "Proportional stretch & fill"], type="index", value="Stretch", visible=False)
sd_generate = \
gr.Button('Generate').style(full_width=True)
# Center Column
with gr.Column():
# elem_id='output_gallery' is targeted by the custom CSS.
sd_output_image = \
gr.Gallery(show_label=False, elem_id='output_gallery').style(grid=3)
sd_output_html = \
gr.HTML()
# Right Column
with gr.Column():
with gr.Row():
# type="index": the handler receives the position of the chosen entry in `samplers`.
sd_sampling_method = \
gr.Dropdown(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index")
sd_sampling_steps = \
gr.Slider(label="Sampling steps", value=30, minimum=5, maximum=100, step=5)
with gr.Group():
sd_cfg = \
gr.Slider(label='Prompt similarity (CFG)', value=8.0, minimum=1.0, maximum=15.0, step=0.5)
sd_denoise = \
gr.Slider(label='Denoising Strength (DNS)', value=0.75, minimum=0.0, maximum=1.0, step=0.01, visible=False)
# The seed field below is revealed by this checkbox's change handler.
sd_use_input_seed = \
gr.Checkbox(label='Custom seed')
sd_input_seed = \
gr.Number(value=-1, visible=False, show_label=False)
# Only shown when GFPGAN is available.
sd_facefix = \
gr.Checkbox(label='GFPGAN', value=False, visible=GFPGAN is not None)
# NOTE(review): created hidden, and no handler in the visible code ever sets it
# visible, so the strength appears to stay at its default of 1 - confirm.
sd_facefix_strength = \
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None, visible=False)
sd_matrix = \
gr.Checkbox(label='Create prompt matrix', value=False)
sd_loopback = \
gr.Checkbox(label='Output loopback', value=False, visible=False)
sd_upscale = \
gr.Checkbox(label='Super resolution upscale', value=False, visible=False)
with gr.TabItem('Settings', id='settings_tab'):
# TODO: fix this
# A gr.Interface is built directly inside the tab: one input component per key
# of opts.data_labels, handled by run_settings.
gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never")
def mode_change(mode: str):
    """Return the visibility update of every mode-dependent widget.

    Args:
        mode: current value of the Mode dropdown ('Text-to-Image',
            'Image-to-Image' or 'Post-Processing').

    Returns:
        A dict mapping each affected component to a ``gr.update(visible=...)``.
    """
    is_img2img = mode == 'Image-to-Image'
    is_txt2img = mode == 'Text-to-Image'
    is_pp = mode == 'Post-Processing'
    # Controls shared by both generation modes.
    is_generation = is_img2img or is_txt2img

    return {
        sd_cfg: gr.update(visible=is_generation),
        sd_denoise: gr.update(visible=is_img2img),
        sd_sampling_method: gr.update(visible=is_generation),
        sd_sampling_steps: gr.update(visible=is_generation),
        sd_batch_count: gr.update(visible=is_generation),
        sd_batch_size: gr.update(visible=is_generation),
        sd_input_image: gr.update(visible=is_img2img or is_pp),
        sd_resize_mode: gr.update(visible=is_img2img),
        sd_image_height: gr.update(visible=is_generation),
        sd_image_width: gr.update(visible=is_generation),
        sd_use_input_seed: gr.update(visible=is_generation),
        # TODO: can we handle this by updating use_input_seed and having its callback handle it?
        # The seed field is re-hidden on every mode switch, whatever the
        # state of the 'Custom seed' checkbox.
        sd_input_seed: gr.update(visible=False),
        # The checkbox is created with visible=GFPGAN is not None; keep that
        # condition here so a mode switch does not expose it when GFPGAN is
        # unavailable.
        sd_facefix: gr.update(visible=GFPGAN is not None),
        # TODO: see above, but for facefix
        sd_facefix_strength: gr.update(visible=False),
        sd_matrix: gr.update(visible=is_generation),
        sd_loopback: gr.update(visible=is_img2img),
        sd_upscale: gr.update(visible=is_img2img)
    }
# Recompute widget visibility whenever the Mode dropdown changes.
# mode_change returns a dict keyed by component, so every component it
# updates has to be listed in `outputs` here.
sd_mode.change(
fn=mode_change,
inputs=sd_mode,
outputs=[
sd_cfg,
sd_denoise,
sd_sampling_method,
sd_sampling_steps,
sd_batch_count,
sd_batch_size,
sd_input_image,
sd_resize_mode,
sd_image_height,
sd_image_width,
sd_use_input_seed,
sd_input_seed,
sd_facefix,
sd_facefix_strength,
sd_matrix,
sd_loopback,
sd_upscale
]
)
# Shared wiring for both ways of starting a run (Enter in the prompt box and
# the Generate button). The inputs follow do_generate's parameter order.
do_generate_args = {
    "fn": wrap_gradio_call(do_generate),
    "inputs": [
        sd_mode, sd_prompt,
        sd_cfg, sd_denoise,
        sd_sampling_method, sd_sampling_steps,
        sd_batch_count, sd_batch_size,
        sd_input_image, sd_resize_mode,
        sd_image_height, sd_image_width,
        sd_use_input_seed, sd_input_seed,
        sd_facefix, sd_facefix_strength,
        sd_matrix, sd_loopback, sd_upscale,
    ],
    "outputs": [sd_output_image, sd_output_html],
}

sd_prompt.submit(**do_generate_args)
sd_generate.click(**do_generate_args)
def _snap_dimension(value):
    """Round a positive size up to the next multiple of 64; otherwise return 512."""
    # An emptied Number field is delivered as None, which cannot be compared
    # with 0; treat it like any other invalid entry.
    if value is None or value <= 0:
        return 512
    return 64 * ((value + 63) // 64)


def _at_least_one(value):
    """Return ``value`` when it is positive, otherwise 1."""
    if value is None or value <= 0:
        return 1
    return value


# Show the seed field only while 'Custom seed' is ticked.
sd_use_input_seed.change(
    lambda checked: gr.update(visible=checked),
    inputs=sd_use_input_seed,
    outputs=sd_input_seed
)

# Normalise the numeric fields in place when the user submits them.
sd_image_height.submit(
    _snap_dimension,
    inputs=sd_image_height,
    outputs=sd_image_height
)

sd_image_width.submit(
    _snap_dimension,
    inputs=sd_image_width,
    outputs=sd_image_width
)

sd_batch_count.submit(
    _at_least_one,
    inputs=sd_batch_count,
    outputs=sd_batch_count
)

sd_batch_size.submit(
    _at_least_one,
    inputs=sd_batch_size,
    outputs=sd_batch_size
)
demo.launch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment