Commit 35ac3a66 authored by deggua's avatar deggua

Improved error messages, formatting, and fixed the settings menu

parent aa74ac7e
import argparse import argparse
from doctest import OutputChecker
import os import os
from pydoc import visiblename
import sys import sys
from collections import namedtuple from collections import namedtuple
import torch import torch
...@@ -18,6 +16,7 @@ import html ...@@ -18,6 +16,7 @@ import html
import time import time
import json import json
import traceback import traceback
from datetime import datetime
import k_diffusion.sampling import k_diffusion.sampling
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
...@@ -215,8 +214,11 @@ def sanitize_filename_part(text): ...@@ -215,8 +214,11 @@ def sanitize_filename_part(text):
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128] return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
def plaintext_to_html(text): def plaintext_to_html(text, klass=None):
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')]) if klass is None:
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
else:
text = "".join([f"<p class=\"{klass}\">{html.escape(x)}</p>\n" for x in text.split('\n')])
return text return text
...@@ -701,6 +703,22 @@ class KDiffusionSampler: ...@@ -701,6 +703,22 @@ class KDiffusionSampler:
return samples_ddim return samples_ddim
class OutputInfo:
def __init__(self, prompt: str, params: str, comments: str):
self.prompt = prompt.strip()
self.params = params.strip()
self.comments = comments.strip()
def __str__(self):
return '\n'.join([self.prompt, self.params, self.comments])
def html(self) -> str:
return f'''
{plaintext_to_html(self.prompt, "prompt-info")}<br>
{plaintext_to_html(self.params, "params-info")}
{plaintext_to_html(self.comments, "comments-info")}
'''
def process_images(p: StableDiffusionProcessing): def process_images(p: StableDiffusionProcessing):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
...@@ -758,7 +776,7 @@ def process_images(p: StableDiffusionProcessing): ...@@ -758,7 +776,7 @@ def process_images(p: StableDiffusionProcessing):
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
def infotext(): def infotext():
return f'{prompt}\n<p class="performance">{generation_params_text}</p>'.strip() + "".join(["\n\n" + x for x in comments]) return OutputInfo(prompt, generation_params_text, "".join(["\n\n" + x for x in comments]))
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model) model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
...@@ -803,7 +821,7 @@ def process_images(p: StableDiffusionProcessing): ...@@ -803,7 +821,7 @@ def process_images(p: StableDiffusionProcessing):
else: else:
output_image = Image.fromarray(x_sample) output_image = Image.fromarray(x_sample)
save_image(output_image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext()) save_image(output_image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=str(infotext()))
output_images.append(output_image) output_images.append(output_image)
base_count += 1 base_count += 1
...@@ -822,7 +840,7 @@ def process_images(p: StableDiffusionProcessing): ...@@ -822,7 +840,7 @@ def process_images(p: StableDiffusionProcessing):
else: else:
grid = image_grid(output_images, p.batch_size) grid = image_grid(output_images, p.batch_size)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=str(infotext()), short_filename=not opts.grid_extended_filename)
grid_count += 1 grid_count += 1
torch_gc() torch_gc()
...@@ -861,7 +879,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, ...@@ -861,7 +879,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
output_images, seed, info = process_images(p) output_images, seed, info = process_images(p)
return output_images, plaintext_to_html(info) return output_images, info.html()
class Flagging(gr.FlaggingCallback): class Flagging(gr.FlaggingCallback):
...@@ -919,6 +937,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -919,6 +937,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.sampler = samplers_for_img2img[self.sampler_index].constructor() self.sampler = samplers_for_img2img[self.sampler_index].constructor()
imgs = [] imgs = []
if not self.init_images or None in self.init_images:
raise Exception('No input image provided for Image-to-Image')
for img in self.init_images: for img in self.init_images:
image = img.convert("RGB") image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height) image = resize_image(self.resize_mode, image, self.width, self.height)
...@@ -1000,7 +1022,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG ...@@ -1000,7 +1022,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1) grid = image_grid(history, batch_size, force_n_rows=1)
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename) save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=str(info), short_filename=not opts.grid_extended_filename)
output_images = history output_images = history
seed = initial_seed seed = initial_seed
...@@ -1049,7 +1071,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG ...@@ -1049,7 +1071,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
combined_image = combine_grid(grid) combined_image = combine_grid(grid)
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename) save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=str(initial_info), short_filename=not opts.grid_extended_filename)
output_images = [combined_image] output_images = [combined_image]
seed = initial_seed seed = initial_seed
...@@ -1058,7 +1080,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG ...@@ -1058,7 +1080,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
else: else:
output_images, seed, info = process_images(p) output_images, seed, info = process_images(p)
return output_images, plaintext_to_html(info) return output_images, info.html()
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index): def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
info = realesrgan_models[RealESRGAN_model_index] info = realesrgan_models[RealESRGAN_model_index]
...@@ -1080,6 +1102,9 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) ...@@ -1080,6 +1102,9 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index): def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
torch_gc() torch_gc()
if not image:
raise Exception('No input image provided for Post-Processing')
image = image.convert("RGB") image = image.convert("RGB")
outpath = opts.outdir or "outputs/extras-samples" outpath = opts.outdir or "outputs/extras-samples"
...@@ -1111,15 +1136,12 @@ if os.path.exists(config_filename): ...@@ -1111,15 +1136,12 @@ if os.path.exists(config_filename):
def run_settings(*args): def run_settings(*args):
up = [] for key, value in zip(opts.data_labels.keys(), args):
for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components):
opts.data[key] = value opts.data[key] = value
up.append(comp.update(value=value))
opts.save(config_filename) opts.save(config_filename)
return 'Settings saved.', '' return plaintext_to_html(f'Settings saved @ {datetime.now().strftime("%I:%M:%S")}')
def create_setting_component(key): def create_setting_component(key):
...@@ -1227,8 +1249,8 @@ css_hide_progressbar = \ ...@@ -1227,8 +1249,8 @@ css_hide_progressbar = \
main_css = \ main_css = \
""" """
.output-html p {margin: 0 0.5em;} .output-html p { margin: 0 0.5em; }
.performance { font-size: 0.85em; color: #444; } .performance, .params-info, .comments-info { font-size: 0.85em; color: #666; }
""" """
# [data-testid="image"] {min-height: 512px !important} # [data-testid="image"] {min-height: 512px !important}
...@@ -1383,17 +1405,12 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1383,17 +1405,12 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
gr.Checkbox(label='Super resolution upscale', value=False, visible=False) gr.Checkbox(label='Super resolution upscale', value=False, visible=False)
with gr.TabItem('Settings', id='settings_tab'): with gr.TabItem('Settings', id='settings_tab'):
# TODO: fix this # TODO: Add HTML output to indicate settings saved
gr.Interface( sd_settings = [create_setting_component(key) for key in opts.data_labels.keys()]
run_settings, sd_save_settings = \
inputs=[create_setting_component(key) for key in opts.data_labels.keys()], gr.Button('Save')
outputs=[ sd_confirm_settings = \
gr.Textbox(label='Result'), gr.HTML()
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never")
def mode_change(mode: str, facefix: bool, custom_seed: bool): def mode_change(mode: str, facefix: bool, custom_seed: bool):
is_img2img = (mode == 'Image-to-Image') is_img2img = (mode == 'Image-to-Image')
...@@ -1518,4 +1535,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We ...@@ -1518,4 +1535,10 @@ with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion We
outputs=sd_facefix_strength outputs=sd_facefix_strength
) )
sd_save_settings.click(
fn=run_settings,
inputs=sd_settings,
outputs=sd_confirm_settings
)
demo.launch() demo.launch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment