Commit 35ac3a66 authored by deggua's avatar deggua

Improved error messages, formatting, and fixed the settings menu

parent aa74ac7e
import argparse
from doctest import OutputChecker
import os
from pydoc import visiblename
import sys
from collections import namedtuple
import torch
......@@ -18,6 +16,7 @@ import html
import time
import json
import traceback
from datetime import datetime
import k_diffusion.sampling
from ldm.util import instantiate_from_config
......@@ -215,8 +214,11 @@ def sanitize_filename_part(text):
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
def plaintext_to_html(text):
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
def plaintext_to_html(text, klass=None):
if klass is None:
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
else:
text = "".join([f"<p class=\"{klass}\">{html.escape(x)}</p>\n" for x in text.split('\n')])
return text
......@@ -701,6 +703,22 @@ class KDiffusionSampler:
return samples_ddim
class OutputInfo:
def __init__(self, prompt: str, params: str, comments: str):
self.prompt = prompt.strip()
self.params = params.strip()
self.comments = comments.strip()
def __str__(self):
return '\n'.join([self.prompt, self.params, self.comments])
def html(self) -> str:
return f'''
{plaintext_to_html(self.prompt, "prompt-info")}<br>
{plaintext_to_html(self.params, "params-info")}
{plaintext_to_html(self.comments, "comments-info")}
'''
def process_images(p: StableDiffusionProcessing):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
......@@ -758,7 +776,7 @@ def process_images(p: StableDiffusionProcessing):
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
def infotext():
return f'{prompt}\n<p class="performance">{generation_params_text}</p>'.strip() + "".join(["\n\n" + x for x in comments])
return OutputInfo(prompt, generation_params_text, "".join(["\n\n" + x for x in comments]))
if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
......@@ -803,7 +821,7 @@ def process_images(p: StableDiffusionProcessing):
else:
output_image = Image.fromarray(x_sample)
save_image(output_image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
save_image(output_image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=str(infotext()))
output_images.append(output_image)
base_count += 1
......@@ -822,7 +840,7 @@ def process_images(p: StableDiffusionProcessing):
else:
grid = image_grid(output_images, p.batch_size)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=str(infotext()), short_filename=not opts.grid_extended_filename)
grid_count += 1
torch_gc()
......@@ -861,7 +879,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
output_images, seed, info = process_images(p)
return output_images, plaintext_to_html(info)
return output_images, info.html()
class Flagging(gr.FlaggingCallback):
......@@ -919,6 +937,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
imgs = []
if not self.init_images or None in self.init_images:
raise Exception('No input image provided for Image-to-Image')
for img in self.init_images:
image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height)
......@@ -1000,7 +1022,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1)
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=str(info), short_filename=not opts.grid_extended_filename)
output_images = history
seed = initial_seed
......@@ -1049,7 +1071,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
combined_image = combine_grid(grid)
grid_count = len(os.listdir(outpath)) - 1
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=str(initial_info), short_filename=not opts.grid_extended_filename)
output_images = [combined_image]
seed = initial_seed
......@@ -1058,7 +1080,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
else:
output_images, seed, info = process_images(p)
return output_images, plaintext_to_html(info)
return output_images, info.html()
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
info = realesrgan_models[RealESRGAN_model_index]
......@@ -1080,6 +1102,9 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
torch_gc()
if not image:
raise Exception('No input image provided for Post-Processing')
image = image.convert("RGB")
outpath = opts.outdir or "outputs/extras-samples"
......@@ -1111,15 +1136,12 @@ if os.path.exists(config_filename):
def run_settings(*args):
up = []
for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components):
for key, value in zip(opts.data_labels.keys(), args):
opts.data[key] = value
up.append(comp.update(value=value))
opts.save(config_filename)
return 'Settings saved.', ''
return plaintext_to_html(f'Settings saved @ {datetime.now().strftime("%I:%M:%S")}')
def create_setting_component(key):
......@@ -1227,8 +1249,8 @@ css_hide_progressbar = \
main_css = \
"""
.output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; }
.output-html p { margin: 0 0.5em; }
.performance, .params-info, .comments-info { font-size: 0.85em; color: #666; }
"""
# [data-testid="image"] {min-height: 512px !important}
......@@ -1383,17 +1405,12 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
gr.Checkbox(label='Super resolution upscale', value=False, visible=False)
with gr.TabItem('Settings', id='settings_tab'):
# TODO: fix this
gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never")
# TODO: Add HTML output to indicate settings saved
sd_settings = [create_setting_component(key) for key in opts.data_labels.keys()]
sd_save_settings = \
gr.Button('Save')
sd_confirm_settings = \
gr.HTML()
def mode_change(mode: str, facefix: bool, custom_seed: bool):
is_img2img = (mode == 'Image-to-Image')
......@@ -1518,4 +1535,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
outputs=sd_facefix_strength
)
sd_save_settings.click(
fn=run_settings,
inputs=sd_settings,
outputs=sd_confirm_settings
)
demo.launch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment