Commit 3ec2eb8b authored by InvincibleDude's avatar InvincibleDude Committed by GitHub

Merge branch 'master' into improved-hr-conflict-test

parents 0d834b93 ee9fdf7f
...@@ -20,8 +20,7 @@ model: ...@@ -20,8 +20,7 @@ model:
conditioning_key: hybrid conditioning_key: hybrid
monitor: val/loss_simple_ema monitor: val/loss_simple_ema
scale_factor: 0.18215 scale_factor: 0.18215
use_ema: true use_ema: false
load_ema: true
scheduler_config: # 10000 warmup steps scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler target: ldm.lr_scheduler.LambdaLinearScheduler
......
...@@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) { ...@@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt] return [prompt, negative_prompt]
} }
promptTokecountUpdateFuncs = {}
function recalculatePromptTokens(name){
if(promptTokecountUpdateFuncs[name]){
promptTokecountUpdateFuncs[name]()
}
}
function recalculate_prompts_txt2img(){
recalculatePromptTokens('txt2img_prompt')
recalculatePromptTokens('txt2img_neg_prompt')
return args_to_array(arguments);
}
function recalculate_prompts_img2img(){
recalculatePromptTokens('img2img_prompt')
recalculatePromptTokens('img2img_neg_prompt')
return args_to_array(arguments);
}
opts = {} opts = {}
onUiUpdate(function(){ onUiUpdate(function(){
if(Object.keys(opts).length != 0) return; if(Object.keys(opts).length != 0) return;
...@@ -232,14 +254,12 @@ onUiUpdate(function(){ ...@@ -232,14 +254,12 @@ onUiUpdate(function(){
return return
} }
prompt.parentElement.insertBefore(counter, prompt) prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter") counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative" prompt.parentElement.style.position = "relative"
textarea.addEventListener("input", function(){ promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
update_token_counter(id_button); textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
});
} }
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
...@@ -273,7 +293,7 @@ onOptionsChanged(function(){ ...@@ -273,7 +293,7 @@ onOptionsChanged(function(){
let txt2img_textarea, img2img_textarea = undefined; let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800 let wait_time = 800
let token_timeout; let token_timeouts = {};
function update_txt2img_tokens(...args) { function update_txt2img_tokens(...args) {
update_token_counter("txt2img_token_button") update_token_counter("txt2img_token_button")
...@@ -290,9 +310,9 @@ function update_img2img_tokens(...args) { ...@@ -290,9 +310,9 @@ function update_img2img_tokens(...args) {
} }
function update_token_counter(button_id) { function update_token_counter(button_id) {
if (token_timeout) if (token_timeouts[button_id])
clearTimeout(token_timeout); clearTimeout(token_timeouts[button_id]);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
function restart_reload(){ function restart_reload(){
......
...@@ -223,6 +223,7 @@ def prepare_environment(): ...@@ -223,6 +223,7 @@ def prepare_environment():
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
...@@ -282,7 +283,7 @@ def prepare_environment(): ...@@ -282,7 +283,7 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers: if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
else: else:
print("Installation of xformers is not supported in this version of Python.") print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
......
import base64 import base64
import html
import io import io
import math import math
import os import os
...@@ -16,13 +17,23 @@ re_param = re.compile(re_param_code) ...@@ -16,13 +17,23 @@ re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update()) type_of_gr_update = type(gr.update())
paste_fields = {} paste_fields = {}
bind_list = [] registered_param_bindings = []
class ParamBinding:
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
def reset(): def reset():
paste_fields.clear() paste_fields.clear()
bind_list.clear()
def quote(text): def quote(text):
...@@ -74,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields): ...@@ -74,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields):
modules.ui.img2img_paste_fields = fields modules.ui.img2img_paste_fields = fields
def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
'eta_noise_seed_delta': 'ENSD',
'initial_noise_multiplier': 'Noise multiplier',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
for k, v in settings_map.items()
]
for tabname, info in paste_fields.items():
if info["fields"] is not None:
info["fields"] += settings_paste_fields
def create_buttons(tabs_list): def create_buttons(tabs_list):
buttons = {} buttons = {}
for tab in tabs_list: for tab in tabs_list:
...@@ -101,71 +92,79 @@ def create_buttons(tabs_list): ...@@ -101,71 +92,79 @@ def create_buttons(tabs_list):
return buttons return buttons
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info): def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info]) """old function for backwards compatibility; do not use this, use register_paste_params_button"""
for tabname, button in buttons.items():
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
else:
img = image_from_url_text(x)
if shared.opts.send_size and isinstance(img, Image.Image): def register_paste_params_button(binding: ParamBinding):
w = img.width registered_param_bindings.append(binding)
h = img.height
else:
w = gr.update()
h = gr.update()
return img, w, h
def run_bind(): def connect_paste_params_buttons():
for buttons, source_image_component, send_generate_info in bind_list: binding: ParamBinding
for tab in buttons: for binding in registered_param_bindings:
button = buttons[tab] destination_image_component = paste_fields[binding.tabname]["init_img"]
destination_image_component = paste_fields[tab]["init_img"] fields = paste_fields[binding.tabname]["fields"]
fields = paste_fields[tab]["fields"]
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
if source_image_component and destination_image_component: if binding.source_image_component and destination_image_component:
if isinstance(source_image_component, gr.Gallery): if isinstance(binding.source_image_component, gr.Gallery):
func = send_image_and_dimensions if destination_width_component else image_from_url_text func = send_image_and_dimensions if destination_width_component else image_from_url_text
jsfunc = "extract_image_from_gallery" jsfunc = "extract_image_from_gallery"
else: else:
func = send_image_and_dimensions if destination_width_component else lambda x: x func = send_image_and_dimensions if destination_width_component else lambda x: x
jsfunc = None jsfunc = None
button.click( binding.paste_button.click(
fn=func, fn=func,
_js=jsfunc, _js=jsfunc,
inputs=[source_image_component], inputs=[binding.source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
) )
if send_generate_info and fields is not None: if binding.source_text_component is not None and fields is not None:
if send_generate_info in paste_fields: connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
if binding.source_tabname is not None and fields is not None:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click( binding.paste_button.click(
fn=lambda *x: x, fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
outputs=[field for field, name in fields if name in paste_field_names], outputs=[field for field, name in fields if name in paste_field_names],
) )
else:
connect_paste(button, fields, send_generate_info)
button.click( binding.paste_button.click(
fn=None, fn=None,
_js=f"switch_to_{tab}", _js=f"switch_to_{binding.tabname}",
inputs=None, inputs=None,
outputs=None, outputs=None,
) )
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
else:
img = image_from_url_text(x)
if shared.opts.send_size and isinstance(img, Image.Image):
w = img.width
h = img.height
else:
w = gr.update()
h = gr.update()
return img, w, h
def find_hypernetwork_key(hypernet_name, hypernet_hash=None): def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext. """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
...@@ -290,7 +289,50 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -290,7 +289,50 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res return res
def connect_paste(button, paste_fields, input_comp, jsfunc=None): settings_map = {}
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
]
def create_override_settings_dict(text_pairs):
"""creates processing's override_settings parameters from gradio's multiselect
Example input:
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
Example output:
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
"""
res = {}
params = {}
for pair in text_pairs:
k, v = pair.split(":", maxsplit=1)
params[k] = v.strip()
for param_name, setting_name in infotext_to_setting_name_mapping:
value = params.get(param_name, None)
if value is None:
continue
res[setting_name] = shared.opts.cast_value(setting_name, value)
return res
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt): def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config: if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(data_path, "params.txt") filename = os.path.join(data_path, "params.txt")
...@@ -327,9 +369,35 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): ...@@ -327,9 +369,35 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
return res return res
if override_settings_component is not None:
def paste_settings(params):
vals = {}
for param_name, setting_name in infotext_to_setting_name_mapping:
v = params.get(param_name, None)
if v is None:
continue
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
continue
v = shared.opts.cast_value(setting_name, v)
current_value = getattr(shared.opts, setting_name, None)
if v == current_value:
continue
vals[param_name] = v
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
button.click( button.click(
fn=paste_func, fn=paste_func,
_js=jsfunc, _js=f"recalculate_prompts_{tabname}",
inputs=[input_comp], inputs=[input_comp],
outputs=[x[0] for x in paste_fields], outputs=[x[0] for x in paste_fields],
) )
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from modules import devices, sd_samplers from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
...@@ -75,7 +76,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): ...@@ -75,7 +76,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
if mode == 0: # img2img if mode == 0: # img2img
...@@ -142,6 +145,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s ...@@ -142,6 +145,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
inpaint_full_res=inpaint_full_res, inpaint_full_res=inpaint_full_res,
inpaint_full_res_padding=inpaint_full_res_padding, inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
override_settings=override_settings,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
......
...@@ -455,7 +455,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter ...@@ -455,7 +455,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
} }
......
This diff is collapsed.
from collections import namedtuple
import numpy as np
import torch
from PIL import Image
import torchsde._brownian.brownian_interval
from modules import devices, processing, images, sd_vae_approx
from modules.shared import opts, state
import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
requested_steps = (steps or p.steps)
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
t_enc = requested_steps - 1
else:
steps = p.steps
t_enc = int(min(p.denoising_strength, 0.999) * steps)
return steps, t_enc
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def store_latent(decoded):
state.current_latent = decoded
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
shared.state.assign_current_image(sample_to_image(decoded))
class InterruptedException(BaseException):
pass
# MPS fix for randn in torchsde
# XXX move this to separate file for MPS
def torchsde_randn(size, dtype, device, seed):
if device.type == 'mps':
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
else:
generator = torch.Generator(device).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=device, generator=generator)
torchsde._brownian.brownian_interval._randn = torchsde_randn
import math
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import numpy as np
import torch
from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
self.sampler_noises = None
self.step = 0
self.stop_at = None
self.eta = None
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p):
return 0
def launch_sampling(self, steps, func):
state.sampling_steps = steps
state.sampling_step = 0
try:
return func()
except sd_samplers_common.InterruptedException:
return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
if self.stop_at is not None and self.step > self.stop_at:
raise sd_samplers_common.InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
# filling unconditional_conditioning with repeats of the last vector to match length is
# not 100% correct but should work well enough
if unconditional_conditioning.shape[1] < cond.shape[1]:
last_vector = unconditional_conditioning[:, -1:]
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
elif unconditional_conditioning.shape[1] > cond.shape[1]:
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
else:
self.last_latent = res[1]
sd_samplers_common.store_latent(self.last_latent)
self.step += 1
state.sampling_step = self.step
shared.total_tqdm.update()
return res
def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == math.floor(valid_step):
return int(valid_step) + 1
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
self.last_latent = x
self.step = 0
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
self.last_latent = x
self.step = 0
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None:
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
This diff is collapsed.
...@@ -105,6 +105,8 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ ...@@ -105,6 +105,8 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button") parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_dir, parser)
...@@ -127,12 +129,13 @@ restricted_opts = { ...@@ -127,12 +129,13 @@ restricted_opts = {
ui_reorder_categories = [ ui_reorder_categories = [
"inpaint", "inpaint",
"sampler", "sampler",
"checkboxes",
"hires_fix",
"dimensions", "dimensions",
"cfg", "cfg",
"seed", "seed",
"checkboxes",
"hires_fix",
"batch", "batch",
"override_settings",
"scripts", "scripts",
] ]
...@@ -346,10 +349,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { ...@@ -346,10 +349,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
})) }))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs), "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
})) }))
...@@ -440,7 +443,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -440,7 +443,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
...@@ -605,11 +608,37 @@ class Options: ...@@ -605,11 +608,37 @@ class Options:
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
"""
if value is None:
return None
default_value = self.data_labels[key].default
if default_value is None:
default_value = getattr(self, key, None)
if default_value is None:
return None
expected_type = type(default_value)
if expected_type == bool and value == "False":
value = False
else:
value = expected_type(value)
return value
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
settings_components = None
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent" latent_upscale_default_mode = "Latent"
latent_upscale_modes = { latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False}, "Latent": {"mode": "bilinear", "antialias": False},
......
import modules.scripts import modules.scripts
from modules import sd_samplers from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
...@@ -8,7 +9,9 @@ import modules.processing as processing ...@@ -8,7 +9,9 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -40,7 +43,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step ...@@ -40,7 +43,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_y=hr_resize_y, hr_resize_y=hr_resize_y,
hr_sampler=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else '---', hr_sampler=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else '---',
hr_prompt=hr_prompt, hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
......
...@@ -380,6 +380,7 @@ def apply_setting(key, value): ...@@ -380,6 +380,7 @@ def apply_setting(key, value):
opts.save(shared.config_filename) opts.save(shared.config_filename)
return getattr(opts, key) return getattr(opts, key)
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()
...@@ -433,6 +434,18 @@ def get_value_for_setting(key): ...@@ -433,6 +434,18 @@ def get_value_for_setting(key):
return gr.update(value=value, **args) return gr.update(value=value, **args)
def create_override_settings_dropdown(tabname, row):
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
dropdown.change(
fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
inputs=[dropdown],
outputs=[dropdown],
)
return dropdown
def create_ui(): def create_ui():
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
...@@ -514,6 +527,10 @@ def create_ui(): ...@@ -514,6 +527,10 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "override_settings":
with FormRow(elem_id="txt2img_override_settings_row") as row:
override_settings = create_override_settings_dropdown('txt2img', row)
elif category == "scripts": elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"): with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui() custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
...@@ -535,7 +552,6 @@ def create_ui(): ...@@ -535,7 +552,6 @@ def create_ui():
) )
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
...@@ -569,6 +585,8 @@ def create_ui(): ...@@ -569,6 +585,8 @@ def create_ui():
hr_sampler_index, hr_sampler_index,
hr_prompt, hr_prompt,
hr_negative_prompt, hr_negative_prompt,
override_settings,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
...@@ -632,6 +650,9 @@ def create_ui(): ...@@ -632,6 +650,9 @@ def create_ui():
*modules.scripts.scripts_txt2img.infotext_fields *modules.scripts.scripts_txt2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings,
))
txt2img_preview_params = [ txt2img_preview_params = [
txt2img_prompt, txt2img_prompt,
...@@ -779,6 +800,10 @@ def create_ui(): ...@@ -779,6 +800,10 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
elif category == "override_settings":
with FormRow(elem_id="img2img_override_settings_row") as row:
override_settings = create_override_settings_dropdown('img2img', row)
elif category == "scripts": elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"): with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui() custom_inputs = modules.scripts.scripts_img2img.setup_ui()
...@@ -813,7 +838,6 @@ def create_ui(): ...@@ -813,7 +838,6 @@ def create_ui():
) )
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
...@@ -866,7 +890,8 @@ def create_ui(): ...@@ -866,7 +890,8 @@ def create_ui():
inpainting_mask_invert, inpainting_mask_invert,
img2img_batch_input_dir, img2img_batch_input_dir,
img2img_batch_output_dir, img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir img2img_batch_inpaint_mask_dir,
override_settings,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
...@@ -954,6 +979,9 @@ def create_ui(): ...@@ -954,6 +979,9 @@ def create_ui():
] ]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings,
))
modules.scripts.scripts_current = None modules.scripts.scripts_current = None
...@@ -971,7 +999,11 @@ def create_ui(): ...@@ -971,7 +999,11 @@ def create_ui():
html2 = gr.HTML() html2 = gr.HTML()
with gr.Row(): with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
parameters_copypaste.bind_buttons(buttons, image, generation_info)
for tabname, button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
))
image.change( image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo), fn=wrap_gradio_call(modules.extras.run_pnginfo),
...@@ -1380,6 +1412,7 @@ def create_ui(): ...@@ -1380,6 +1412,7 @@ def create_ui():
components = [] components = []
component_dict = {} component_dict = {}
shared.settings_components = component_dict
script_callbacks.ui_settings_callback() script_callbacks.ui_settings_callback()
opts.reorder() opts.reorder()
...@@ -1546,8 +1579,7 @@ def create_ui(): ...@@ -1546,8 +1579,7 @@ def create_ui():
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.connect_paste_params_buttons()
parameters_copypaste.run_bind()
with gr.Tabs(elem_id="tabs") as tabs: with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
......
...@@ -198,5 +198,9 @@ Requested path was: {f} ...@@ -198,5 +198,9 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}')
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
...@@ -14,6 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): ...@@ -14,6 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
shared.refresh_checkpoints() shared.refresh_checkpoints()
def list_items(self): def list_items(self):
checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items(): for name, checkpoint in sd_models.checkpoints_list.items():
path, ext = os.path.splitext(checkpoint.filename) path, ext = os.path.splitext(checkpoint.filename)
previews = [path + ".png", path + ".preview.png"] previews = [path + ".png", path + ".preview.png"]
...@@ -28,7 +29,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): ...@@ -28,7 +29,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"name": checkpoint.name_for_extra, "name": checkpoint.name_for_extra,
"filename": path, "filename": path,
"preview": preview, "preview": preview,
"search_term": self.search_terms_from_path(checkpoint.filename), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"local_preview": path + ".png", "local_preview": path + ".png",
} }
......
...@@ -52,6 +52,9 @@ else: ...@@ -52,6 +52,9 @@ else:
def check_versions(): def check_versions():
if shared.cmd_opts.skip_version_check:
return
expected_torch_version = "1.13.1" expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version): if version.parse(torch.__version__) < version.parse(expected_torch_version):
...@@ -59,7 +62,10 @@ def check_versions(): ...@@ -59,7 +62,10 @@ def check_versions():
You are running torch {torch.__version__}. You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}. The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch. To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded. Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip()) """.strip())
expected_xformers_version = "0.0.16rc425" expected_xformers_version = "0.0.16rc425"
...@@ -71,6 +77,8 @@ Beware that this will cause a lot of large files to be downloaded. ...@@ -71,6 +77,8 @@ Beware that this will cause a lot of large files to be downloaded.
You are running xformers {xformers.__version__}. You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}. The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers. To reinstall the desired version, run with commandline flag --reinstall-xformers.
Use --skip-version-check commandline argument to disable this check.
""".strip()) """.strip())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment