Commit d61e31ba authored by Robert Barron

Merge remote-tracking branch 'auto1111/dev' into shared-hires-prompt-test

parents 54f926b1 f3b96d49
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def __init__(self):
         super().__init__('lora')

+        self.errors = {}
+        """mapping of network names to the number of errors the network had during operation"""
+
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora

+        self.errors.clear()
+
         if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
             p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

     def deactivate(self, p):
-        pass
+        if self.errors:
+            p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
+
+        self.errors.clear()
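The errors dict sets up a small contract between activate() and deactivate(): code that loads networks tallies failures per network name while the image is generated, and deactivate() surfaces the tally once as a generation comment. A standalone sketch of the pattern, with a hypothetical stand-in for the processing object's comment() collector:

    class FakeProcessing:
        def __init__(self):
            self.comments = []

        def comment(self, text):
            self.comments.append(text)

    errors = {}
    for name in ["lora_a", "lora_b", "lora_a"]:  # e.g. one entry per failed module
        errors[name] = errors.get(name, 0) + 1

    p = FakeProcessing()
    if errors:
        p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in errors.items()))
    print(p.comments)  # ['Networks with errors: lora_a (2), lora_b (1)']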
@@ -133,7 +133,7 @@ class NetworkModule:
         return 1.0

-    def finalize_updown(self, updown, orig_weight, output_shape):
+    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
             updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +145,10 @@ class NetworkModule:
         if orig_weight.size().numel() == updown.size().numel():
             updown = updown.reshape(orig_weight.shape)

-        return updown * self.calc_scale() * self.multiplier()
+        if ex_bias is not None:
+            ex_bias = ex_bias * self.multiplier()
+
+        return updown * self.calc_scale() * self.multiplier(), ex_bias

     def calc_updown(self, target):
         raise NotImplementedError()
...
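finalize_updown now returns an (updown, ex_bias) pair, so callers unpack two deltas and apply the bias only when the module provided one. A hedged sketch of applying such a pair to a plain layer (apply_updown below is illustrative, not the repo's weight-application code):

    import torch

    def apply_updown(layer, updown, ex_bias):
        # the weight delta always applies; the bias delta only when provided
        with torch.no_grad():
            layer.weight += updown
            if ex_bias is not None and layer.bias is not None:
                layer.bias += ex_bias

    layer = torch.nn.Linear(4, 4)
    apply_updown(layer, torch.zeros(4, 4), torch.full((4,), 0.1))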
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["w_norm", "b_norm"]):
+            return NetworkModuleNorm(net, weights)
+
+        return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+
+        self.w_norm = weights.w.get("w_norm")
+        self.b_norm = weights.w.get("b_norm")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.w_norm.shape
+        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        if self.b_norm is not None:
+            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        else:
+            ex_bias = None
+
+        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
This diff is collapsed.
@@ -23,9 +23,9 @@ def unload():
 def before_ui():
     ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())

-    extra_network = extra_networks_lora.ExtraNetworkLora()
-    extra_networks.register_extra_network(extra_network)
-    extra_networks.register_extra_network_alias(extra_network, "lyco")
+    networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
+    extra_networks.register_extra_network(networks.extra_network_lora)
+    extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")


 if not hasattr(torch.nn, 'Linear_forward_before_network'):
@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
     torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
 if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
     torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
 torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
 torch.nn.Conv2d.forward = networks.network_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
 torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
 torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
...
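The hasattr guards make the patching idempotent: each original method is stashed on torch.nn exactly once, so re-running the script (for example on a UI reload) wraps the pristine forward instead of wrapping its own wrapper. The same pattern in isolation, with a hypothetical attribute name and a pass-through wrapper:

    import torch

    # stash the original only on the first run; later runs keep the first stash
    if not hasattr(torch.nn, 'Linear_forward_original_demo'):
        torch.nn.Linear_forward_original_demo = torch.nn.Linear.forward

    def patched_linear_forward(self, input):
        # hypothetical hook point; the real code applies network deltas here
        return torch.nn.Linear_forward_original_demo(self, input)

    torch.nn.Linear.forward = patched_linear_forward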
@@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         item = {
             "name": name,
             "filename": lora_on_disk.filename,
+            "shorthash": lora_on_disk.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(lora_on_disk.filename),
+            "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
             "local_preview": f"{path}.{shared.opts.samples_format}",
             "metadata": lora_on_disk.metadata,
             "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
...
@@ -12,6 +12,7 @@ onUiLoaded(async() => {
         "Sketch": elementIDs.sketch
     };

+
     // Helper functions
     // Get active tab
     function getActiveTab(elements, all = false) {
@@ -377,6 +378,11 @@ onUiLoaded(async() => {
             toggleOverlap("off");
             fullScreenMode = false;

+            const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
+            if (closeBtn) {
+                closeBtn.addEventListener("click", resetZoom);
+            }
+
             if (
                 canvas &&
                 parseFloat(canvas.style.width) > 865 &&
@@ -657,17 +663,20 @@ onUiLoaded(async() => {
         // Simulation of the function to put a long image into the screen.
         // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
         // We hide the image and show it to the user when it is ready.
-        function autoExpand(e) {
+
+        targetElement.isExpanded = false;
+        function autoExpand() {
             const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
             const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;

             if (canvas && isMainTab) {
-                if (hasHorizontalScrollbar(targetElement)) {
+                if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
                     targetElement.style.visibility = "hidden";
                     setTimeout(() => {
                         fitToScreen();
                         resetZoom();
                         targetElement.style.visibility = "visible";
+                        targetElement.isExpanded = true;
                     }, 10);
                 }
             }
@@ -675,9 +684,24 @@ onUiLoaded(async() => {
         targetElement.addEventListener("mousemove", getMousePosition);

+        //observers
+        // Creating an observer with a callback function to handle DOM changes
+        const observer = new MutationObserver((mutationsList, observer) => {
+            for (let mutation of mutationsList) {
+                // If the style attribute of the canvas has changed, by observation it happens only when the picture changes
+                if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
+                    mutation.target.tagName.toLowerCase() === 'canvas') {
+                    targetElement.isExpanded = false;
+                    setTimeout(resetZoom, 10);
+                }
+            }
+        });
+
         // Apply auto expand if enabled
         if (hotkeysConfig.canvas_auto_expand) {
             targetElement.addEventListener("mousemove", autoExpand);
+            // Set up an observer to track attribute changes
+            observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
         }

         // Handle events only inside the targetElement
...
@@ -50,10 +50,12 @@ class PydanticModelGenerator:
         additional_fields = None,
     ):
         def field_type_generator(k, v):
-            # field_type = str if not overrides.get(k) else overrides[k]["type"]
-            # print(k, v.annotation, v.default)
             field_type = v.annotation

+            if field_type == 'Image':
+                # images are sent as base64 strings via API
+                field_type = 'str'
+
             return Optional[field_type]

         def merge_class_params(class_):
@@ -63,7 +65,6 @@ class PydanticModelGenerator:
                 parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
             return parameters

-
         self._model_name = model_name
         self._class_data = merge_class_params(class_instance)
@@ -72,7 +73,7 @@ class PydanticModelGenerator:
                 field=underscore(k),
                 field_alias=k,
                 field_type=field_type_generator(k, v),
-                field_value=v.default
+                field_value=None if isinstance(v.default, property) else v.default
             )
             for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
         ]
...
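The isinstance(v.default, property) guard exists because some attributes of the inspected processing classes are now exposed as properties; a scan of class members then sees the property object itself, which cannot serve as a pydantic field default. Roughly the situation, as a sketch (the Example class is hypothetical, not the generator's exact traversal):

    import inspect

    class Example:
        steps = 20          # a plain default: usable as a field value

        @property
        def seed(self):     # a computed attribute: its "default" is a property object
            return -1

    members = dict(inspect.getmembers(Example))
    print(members['steps'])                       # 20
    print(isinstance(members['seed'], property))  # True -> replace with None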
@@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     process_images(p)


-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     is_batch = mode == 5
@@ -166,12 +166,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         prompt=prompt,
         negative_prompt=negative_prompt,
         styles=prompt_styles,
-        seed=seed,
-        subseed=subseed,
-        subseed_strength=subseed_strength,
-        seed_resize_from_h=seed_resize_from_h,
-        seed_resize_from_w=seed_resize_from_w,
-        seed_enable_extras=seed_enable_extras,
         sampler_name=sampler_name,
         batch_size=batch_size,
         n_iter=n_iter,
...
@@ -173,9 +173,12 @@ def git_clone(url, dir, name, commithash=None):
         if current_hash == commithash:
             return

-        run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
+        if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
+            run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)

-        run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+        run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
+
+        run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)

         return
@@ -243,7 +246,7 @@ def list_extensions(settings_file):
     disabled_extensions = set(settings.get('disabled_extensions', []))
     disable_all_extensions = settings.get('disable_all_extensions', 'none')

-    if disable_all_extensions != 'none':
+    if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions:
         return []

     return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
@@ -319,12 +322,12 @@ def prepare_environment():
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
-    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
+    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

     try:
-        # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
+        # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
         os.remove(os.path.join(script_path, "tmp", "restart"))
         os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
     except OSError:
...
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):

 if has_mps:
-    # MPS fix for randn in torchsde
-    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
         CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
...
@@ -11,37 +11,32 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     shared.state.begin(job="extras")

-    image_data = []
-    image_names = []
     outputs = []

-    if extras_mode == 1:
-        for img in image_folder:
-            if isinstance(img, Image.Image):
-                image = img
-                fn = ''
-            else:
-                image = Image.open(os.path.abspath(img.name))
-                fn = os.path.splitext(img.orig_name)[0]
-            image_data.append(image)
-            image_names.append(fn)
-    elif extras_mode == 2:
-        assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
-        assert input_dir, 'input directory not selected'
-
-        image_list = shared.listfiles(input_dir)
-        for filename in image_list:
-            try:
-                image = Image.open(filename)
-            except Exception:
-                continue
-            image_data.append(image)
-            image_names.append(filename)
-    else:
-        assert image, 'image not selected'
-
-        image_data.append(image)
-        image_names.append(None)
+    def get_images(extras_mode, image, image_folder, input_dir):
+        if extras_mode == 1:
+            for img in image_folder:
+                if isinstance(img, Image.Image):
+                    image = img
+                    fn = ''
+                else:
+                    image = Image.open(os.path.abspath(img.name))
+                    fn = os.path.splitext(img.orig_name)[0]
+                yield image, fn
+        elif extras_mode == 2:
+            assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
+            assert input_dir, 'input directory not selected'
+
+            image_list = shared.listfiles(input_dir)
+            for filename in image_list:
+                try:
+                    image = Image.open(filename)
+                except Exception:
+                    continue
+                yield image, filename
+        else:
+            assert image, 'image not selected'
+            yield image, None

     if extras_mode == 2 and output_dir != '':
         outpath = output_dir
@@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     infotext = ''

-    for image, name in zip(image_data, image_names):
+    for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
+        image_data: Image.Image
+
         shared.state.textinfo = name

-        parameters, existing_pnginfo = images.read_info_from_image(image)
+        parameters, existing_pnginfo = images.read_info_from_image(image_data)
         if parameters:
             existing_pnginfo["parameters"] = parameters

-        pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
+        pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

         scripts.scripts_postproc.run(pp, args)
@@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
         if extras_mode != 2 or show_extras_results:
             outputs.append(pp.image)

+        image_data.close()
+
     devices.torch_gc()

     return outputs, ui_common.plaintext_to_html(infotext), ''
...
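The refactor replaces two parallel lists that opened every image up front with a generator yielding one (image, name) pair at a time, and each image is closed as soon as it is processed, which bounds peak memory for large batch folders. The shape of the change, as a minimal sketch:

    from PIL import Image

    def get_images(filenames):
        # lazy: one file handle and one decoded image alive at a time
        for filename in filenames:
            try:
                image = Image.open(filename)
            except OSError:
                continue
            yield image, filename

    def process_all(filenames):
        for image, name in get_images(filenames):
            image.load()   # stand-in for the real postprocessing work
            image.close()  # release promptly instead of after the whole loop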
This diff is collapsed.
+import gradio as gr
+
+from modules import scripts, sd_models
+from modules.ui_common import create_refresh_button
+from modules.ui_components import InputAccordion
+
+
+class ScriptRefiner(scripts.Script):
+    section = "accordions"
+    create_group = False
+
+    def __init__(self):
+        pass
+
+    def title(self):
+        return "Refiner"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, is_img2img):
+        with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
+            with gr.Row():
+                refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
+                create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
+
+                refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
+
+        def lookup_checkpoint(title):
+            info = sd_models.get_closet_checkpoint_match(title)
+            return None if info is None else info.title
+
+        self.infotext_fields = [
+            (enable_refiner, lambda d: 'Refiner' in d),
+            (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
+            (refiner_switch_at, 'Refiner switch at'),
+        ]
+
+        return enable_refiner, refiner_checkpoint, refiner_switch_at
+
+    def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
+        # the actual implementation is in sd_samplers_common.py, apply_refiner
+        if not enable_refiner or refiner_checkpoint in (None, "", "None"):
+            p.refiner_checkpoint_info = None
+            p.refiner_switch_at = None
+        else:
+            p.refiner_checkpoint = refiner_checkpoint
+            p.refiner_switch_at = refiner_switch_at
+import json
+
+import gradio as gr
+
+from modules import scripts, ui, errors
+from modules.shared import cmd_opts
+from modules.ui_components import ToolButton
+
+
+class ScriptSeed(scripts.ScriptBuiltin):
+    section = "seed"
+    create_group = False
+
+    def __init__(self):
+        self.seed = None
+        self.reuse_seed = None
+        self.reuse_subseed = None
+
+    def title(self):
+        return "Seed"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, is_img2img):
+        with gr.Row(elem_id=self.elem_id("seed_row")):
+            if cmd_opts.use_textbox_seed:
+                self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
+            else:
+                self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
+
+            random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
+            reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
+
+            seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
+
+        with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
+            with gr.Row(elem_id=self.elem_id("subseed_row")):
+                subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
+                random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
+                reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
+                subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
+
+            with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
+                seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
+                seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
+
+        random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
+        random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
+
+        seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
+
+        self.infotext_fields = [
+            (self.seed, "Seed"),
+            (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
+            (subseed, "Variation seed"),
+            (subseed_strength, "Variation seed strength"),
+            (seed_resize_from_w, "Seed resize from-1"),
+            (seed_resize_from_h, "Seed resize from-2"),
+        ]
+
+        self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
+        self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
+
+        return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
+
+    def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
+        p.seed = seed
+
+        if seed_checkbox and subseed_strength > 0:
+            p.subseed = subseed
+            p.subseed_strength = subseed_strength
+
+        if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
+            p.seed_resize_from_w = seed_resize_from_w
+            p.seed_resize_from_h = seed_resize_from_h
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
+    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
+    (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
+    was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+
+    def copy_seed(gen_info_string: str, index):
+        res = -1
+
+        try:
+            gen_info = json.loads(gen_info_string)
+            index -= gen_info.get('index_of_first_image', 0)
+
+            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+                all_subseeds = gen_info.get('all_subseeds', [-1])
+                res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+            else:
+                all_seeds = gen_info.get('all_seeds', [-1])
+                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+        except json.decoder.JSONDecodeError:
+            if gen_info_string:
+                errors.report(f"Error parsing JSON generation info: {gen_info_string}")
+
+        return [res, gr.update()]
+
+    reuse_seed.click(
+        fn=copy_seed,
+        _js="(x, y) => [x, selected_gallery_index()]",
+        show_progress=False,
+        inputs=[generation_info, seed],
+        outputs=[seed, seed]
+    )
This diff is collapsed.
 from __future__ import annotations
 import math
 import psutil
+import platform

 import torch
 from torch import einsum
@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
-    priority = 10
+
+    @property
+    def priority(self):
+        return 1000 if shared.device.type == 'mps' else 10

     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
     @property
     def priority(self):
-        return 1000 if not torch.cuda.is_available() else 10
+        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10

     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
...
@@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")

 def get_closet_checkpoint_match(search_string):
+    if not search_string:
+        return None
+
     checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
         return checkpoint_info
...
@@ -45,18 +45,23 @@ class CFGDenoiser(torch.nn.Module):
         self.nmask = None
         self.init_latent = None
         self.steps = None
+        """number of steps as specified by user in UI"""
+
+        self.total_steps = None
+        """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
+
         self.step = 0
         self.image_cfg_scale = None
         self.padded_cond_uncond = False
         self.sampler = sampler
         self.model_wrap = None
         self.p = None
+        self.mask_before_denoising = False

     @property
     def inner_model(self):
         raise NotImplementedError()

     def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
         denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
@@ -100,7 +105,7 @@ class CFGDenoiser(torch.nn.Module):
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

-        if self.mask is not None:
+        if self.mask_before_denoising and self.mask is not None:
             x = self.init_latent * self.mask + self.nmask * x

         batch_size = len(conds_list)
@@ -202,6 +207,9 @@ class CFGDenoiser(torch.nn.Module):
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

+        if not self.mask_before_denoising and self.mask is not None:
+            denoised = self.init_latent * self.mask + self.nmask * denoised
+
         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

         if opts.live_preview_content == "Prompt":
...
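mask_before_denoising chooses where the inpainting blend is applied: timesteps-based samplers composite the masked latent before the model call, while k-diffusion composites the denoised output afterwards. The blend itself is plain latent interpolation; a minimal sketch, assuming a (B, C, H, W) latent layout:

    import torch

    init_latent = torch.randn(1, 4, 64, 64)   # latents of the original image
    x = torch.randn(1, 4, 64, 64)             # current sample being denoised
    mask = torch.zeros(1, 1, 64, 64)          # 0 = region to repaint
    mask[..., 16:48, 16:48] = 1               # 1 = region to keep
    nmask = 1 - mask

    blended = init_latent * mask + nmask * x  # keep original outside, sample inside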
@@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s
 from modules.shared import opts, state
 import k_diffusion.sampling

-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+class SamplerData(SamplerDataTuple):
+    def total_steps(self, steps):
+        if self.options.get("second_order", False):
+            steps = steps * 2
+
+        return steps


 def setup_img2img_steps(p, steps=None):
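Deriving SamplerData from the namedtuple keeps it a drop-in tuple everywhere while adding derived behavior: second-order samplers invoke the denoiser twice per UI step, so ratios such as the refiner switch point must be computed against total_steps rather than steps. A usage sketch:

    from collections import namedtuple

    SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

    class SamplerData(SamplerDataTuple):
        def total_steps(self, steps):
            if self.options.get("second_order", False):
                steps = steps * 2
            return steps

    dpmpp_sde = SamplerData('DPM++ SDE', None, ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True})
    print(dpmpp_sde.total_steps(20))  # 40: two denoiser calls per UI step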
@@ -83,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
         model = shared.sd_model
         image = image.to(shared.device, dtype=devices.dtype_vae)
         image = image * 2 - 1
-        x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+        if len(image) > 1:
+            x_latent = torch.stack([
+                model.get_first_stage_encoding(
+                    model.encode_first_stage(torch.unsqueeze(img, 0))
+                )[0]
+                for img in image
+            ])
+        else:
+            x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))

     return x_latent
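Encoding one image at a time and stacking the latents keeps the VAE's peak memory at single-image levels when a multi-image batch arrives; shape-wise the stacked result matches the batched call. A sketch with a stand-in encoder (encode below is hypothetical; the real path goes through model.encode_first_stage and get_first_stage_encoding):

    import torch

    def encode(batch):  # hypothetical latent op that preserves the batch dimension
        return batch.mean(dim=1, keepdim=True)

    image = torch.randn(3, 3, 64, 64)  # batch of 3 images

    # one image at a time: encode (1, C, H, W), take [0], stack back to (3, ...)
    x_latent = torch.stack([encode(torch.unsqueeze(img, 0))[0] for img in image])

    assert torch.allclose(x_latent, encode(image))  # same result as the batched call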
@@ -131,31 +148,29 @@ def replace_torchsde_browinan():
 replace_torchsde_browinan()


-def apply_refiner(sampler):
-    completed_ratio = sampler.step / sampler.steps
+def apply_refiner(cfg_denoiser):
+    completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
+    refiner_switch_at = cfg_denoiser.p.refiner_switch_at
+    refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info

-    if completed_ratio <= shared.opts.sd_refiner_switch_at:
+    if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
         return False

-    if shared.opts.sd_refiner_checkpoint == "None":
+    if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
         return False

-    if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
+    if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
         return False

-    refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
-    if refiner_checkpoint_info is None:
-        raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
-
-    sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
-    sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+    cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+    cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at

     with sd_models.SkipWritingToConfig():
         sd_models.reload_model_weights(info=refiner_checkpoint_info)

     devices.torch_gc()

-    sampler.p.setup_conds()
-    sampler.update_inner_model()
+    cfg_denoiser.p.setup_conds()
+    cfg_denoiser.update_inner_model()

     return True
@@ -192,7 +207,7 @@ class Sampler:
         self.sampler_noises = None
         self.stop_at = None
         self.eta = None
-        self.config = None  # set by the function calling the constructor
+        self.config: SamplerData = None  # set by the function calling the constructor
         self.last_latent = None
         self.s_min_uncond = None
         self.s_churn = 0.0
@@ -208,6 +223,7 @@ class Sampler:
         self.p = None
         self.model_wrap_cfg = None
         self.sampler_extra_args = None
+        self.options = {}

     def callback_state(self, d):
         step = d['i']
@@ -220,6 +236,7 @@ class Sampler:
     def launch_sampling(self, steps, func):
         self.model_wrap_cfg.steps = steps
+        self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
         state.sampling_steps = steps
         state.sampling_step = 0
@@ -267,19 +284,19 @@ class Sampler:
         s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax  # 0 = inf
         s_noise = getattr(opts, 's_noise', p.s_noise)

-        if s_churn != self.s_churn:
+        if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
             extra_params_kwargs['s_churn'] = s_churn
             p.s_churn = s_churn
             p.extra_generation_params['Sigma churn'] = s_churn
-        if s_tmin != self.s_tmin:
+        if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
             extra_params_kwargs['s_tmin'] = s_tmin
             p.s_tmin = s_tmin
             p.extra_generation_params['Sigma tmin'] = s_tmin
-        if s_tmax != self.s_tmax:
+        if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
             extra_params_kwargs['s_tmax'] = s_tmax
             p.s_tmax = s_tmax
             p.extra_generation_params['Sigma tmax'] = s_tmax
-        if s_noise != self.s_noise:
+        if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
             extra_params_kwargs['s_noise'] = s_noise
             p.s_noise = s_noise
             p.extra_generation_params['Sigma noise'] = s_noise
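The added 'in extra_params_kwargs' checks work because extra_params_kwargs is pre-seeded from sampler_extra_params, i.e. only with parameter names the selected sampling function actually declares; now that several samplers accept only s_noise, unconditionally injecting s_churn would raise TypeError. The gating idea in isolation, with demo functions:

    def sample_euler_demo(x, s_churn=0.0, s_noise=1.0):
        return x

    def sample_dpmpp_sde_demo(x, s_noise=1.0):
        return x

    sampler_extra_params = {
        'sample_euler_demo': ['s_churn', 's_noise'],
        'sample_dpmpp_sde_demo': ['s_noise'],
    }

    def build_kwargs(funcname, overrides):
        allowed = sampler_extra_params.get(funcname, [])
        # only forward an override if this sampler declares the parameter
        return {k: v for k, v in overrides.items() if k in allowed}

    print(build_kwargs('sample_dpmpp_sde_demo', {'s_churn': 0.5, 's_noise': 1.1}))
    # {'s_noise': 1.1} -- s_churn dropped instead of raising TypeError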
@@ -296,5 +313,8 @@ class Sampler:
         current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
         return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        raise NotImplementedError()
+
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        raise NotImplementedError()
@@ -22,6 +22,12 @@ samplers_k_diffusion = [
     ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
     ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
     ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
+    ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
+    ('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
+    ('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
+    ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
+    ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
+    ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
     ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
     ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
     ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@@ -42,6 +48,12 @@ sampler_extra_params = {
     'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
     'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
     'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_dpm_fast': ['s_noise'],
+    'sample_dpm_2_ancestral': ['s_noise'],
+    'sample_dpmpp_2s_ancestral': ['s_noise'],
+    'sample_dpmpp_sde': ['s_noise'],
+    'sample_dpmpp_2m_sde': ['s_noise'],
+    'sample_dpmpp_3m_sde': ['s_noise'],
 }

 k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
@@ -64,9 +76,12 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
 class KDiffusionSampler(sd_samplers_common.Sampler):
-    def __init__(self, funcname, sd_model):
+    def __init__(self, funcname, sd_model, options=None):
         super().__init__(funcname)

+        self.extra_params = sampler_extra_params.get(funcname, [])
+        self.options = options or {}
+
         self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

         self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
@@ -149,6 +164,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
             noise_sampler = self.create_noise_sampler(x, sigmas, p)
             extra_params_kwargs['noise_sampler'] = noise_sampler

+        if self.config.options.get('solver_type', None) == 'heun':
+            extra_params_kwargs['solver_type'] = 'heun'
+
         self.model_wrap_cfg.init_latent = x
         self.last_latent = x
         self.sampler_extra_args = {
@@ -190,6 +208,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
             noise_sampler = self.create_noise_sampler(x, sigmas, p)
             extra_params_kwargs['noise_sampler'] = noise_sampler

+        if self.config.options.get('solver_type', None) == 'heun':
+            extra_params_kwargs['solver_type'] = 'heun'
+
         self.last_latent = x
         self.sampler_extra_args = {
             'cond': conditioning,
@@ -198,6 +219,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
             'cond_scale': p.cfg_scale,
             's_min_uncond': self.s_min_uncond
         }
+
         samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

         if self.model_wrap_cfg.padded_cond_uncond:
...
@@ -49,12 +49,12 @@ class CFGDenoiserTimesteps(CFGDenoiser):
         super().__init__(sampler)

         self.alphas = shared.sd_model.alphas_cumprod
+        self.mask_before_denoising = True

     def get_pred_x0(self, x_in, x_out, sigma):
-        ts = int(sigma.item())
+        ts = sigma.to(dtype=int)

-        s_in = x_in.new_ones([x_in.shape[0]])
-        a_t = self.alphas[ts].item() * s_in
+        a_t = self.alphas[ts][:, None, None, None]
         sqrt_one_minus_at = (1 - a_t).sqrt()

         pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
...
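The rewrite vectorizes get_pred_x0 over the batch: indexing alphas with a tensor of per-sample timesteps yields a shape-(B,) vector, and [:, None, None, None] reshapes it to (B, 1, 1, 1) so it broadcasts against (B, C, H, W) latents, giving one alpha per sample instead of a single scalar shared by the whole batch. In isolation:

    import torch

    alphas = torch.linspace(0.01, 0.99, steps=1000)  # cumulative alphas per timestep
    ts = torch.tensor([10, 500, 900])                # a different timestep per sample
    x = torch.randn(3, 4, 64, 64)

    a_t = alphas[ts][:, None, None, None]            # (3,) -> (3, 1, 1, 1)
    scaled = a_t.sqrt() * x                          # broadcasts over (3, 4, 64, 64)
    print(a_t.shape, scaled.shape)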
@@ -11,21 +11,22 @@ from modules.models.diffusion.uni_pc import uni_pc
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

     extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
+    s_in = x.new_ones((x.shape[0]))
+    s_x = x.new_ones((x.shape[0], 1, 1, 1))
     for i in tqdm.trange(len(timesteps) - 1, disable=disable):
         index = len(timesteps) - 1 - i

         e_t = model(x, timesteps[index].item() * s_in, **extra_args)

-        a_t = alphas[index].item() * s_in
-        a_prev = alphas_prev[index].item() * s_in
-        sigma_t = sigmas[index].item() * s_in
-        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+        a_t = alphas[index].item() * s_x
+        a_prev = alphas_prev[index].item() * s_x
+        sigma_t = sigmas[index].item() * s_x
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x

         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
@@ -42,18 +43,19 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
+    s_x = x.new_ones((x.shape[0], 1, 1, 1))
     old_eps = []

     def get_x_prev_and_pred_x0(e_t, index):
         # select parameters corresponding to the currently considered timestep
-        a_t = alphas[index].item() * s_in
-        a_prev = alphas_prev[index].item() * s_in
-        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+        a_t = alphas[index].item() * s_x
+        a_prev = alphas_prev[index].item() * s_x
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x

         # current prediction for x_0
         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
...
@@ -31,7 +31,9 @@ def get_loaded_vae_hash():
     if loaded_vae_file is None:
         return None

-    return hashes.sha256(loaded_vae_file, 'vae')[0:10]
+    sha256 = hashes.sha256(loaded_vae_file, 'vae')
+
+    return sha256[0:10] if sha256 else None


 def get_base_vae(model):
...
@@ -69,10 +69,11 @@ def reload_hypernetworks():
 ui_reorder_categories_builtin_items = [
     "inpaint",
     "sampler",
+    "accordions",
     "checkboxes",
-    "hires_fix",
     "dimensions",
     "cfg",
+    "denoising",
     "seed",
     "batch",
     "override_settings",
@@ -86,7 +87,7 @@ def ui_reorder_categories():
     sections = {}
     for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
-        if isinstance(script.section, str):
+        if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
             sections[script.section] = 1

     yield from sections
...
@@ -140,8 +140,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
     "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
-    "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
-    "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
 }))

 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
@@ -288,12 +286,12 @@ options_templates.update(options_section(('ui', "Live previews"), {
 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
     "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
     "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
-    "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; applies to Euler a and other samplers that have a in them"),
+    "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
     "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
     's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
     's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
     's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
-    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
+    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
     'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
     'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
     'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
...
@@ -58,7 +58,7 @@ def _summarize_chunk(
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
...
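Both hunks above make the same one-word change. `torch.baddbmm(input, batch1, batch2, beta=..., alpha=...)` computes `beta * input + alpha * (batch1 @ batch2)`; the bias tensor is meant to be irrelevant here, but `torch.empty` returns uninitialized memory that may contain NaN or inf, and not every backend reliably skips the bias read, so a zero-filled tensor makes the result well-defined on all devices (my reading of the change; the diff itself gives no rationale). A standalone sketch:

```python
import torch

query = torch.randn(2, 4, 8)   # (batch, q_tokens, head_dim)
key = torch.randn(2, 6, 8)     # (batch, k_tokens, head_dim)
scale = 8 ** -0.5

# out = beta * bias + alpha * (query @ key^T); with beta=0 the zero bias is
# mathematically a no-op, but unlike torch.empty it can never inject NaNs.
attn_scores = torch.baddbmm(
    torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
    query,
    key.transpose(1, 2),
    beta=0,
    alpha=scale,
)
assert attn_scores.shape == (2, 4, 6)
```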
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr


-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     p = processing.StableDiffusionProcessingTxt2Img(
@@ -19,12 +19,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         prompt=prompt,
         styles=prompt_styles,
         negative_prompt=negative_prompt,
-        seed=seed,
-        subseed=subseed,
-        subseed_strength=subseed_strength,
-        seed_resize_from_h=seed_resize_from_h,
-        seed_resize_from_w=seed_resize_from_w,
-        seed_enable_extras=seed_enable_extras,
         sampler_name=sampler_name,
         batch_size=batch_size,
         n_iter=n_iter,
...
This diff is collapsed.
@@ -137,13 +137,17 @@ Requested path was: {f}
     generation_info = None
     with gr.Column():
         with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
-            open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
+            open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")

             if tabname != "extras":
-                save = gr.Button('Save', elem_id=f'save_{tabname}')
-                save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+                save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
+                save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")

-            buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+            buttons = {
+                'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
+                'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
+                'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
+            }

         open_folder_button.click(
             fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
...
@@ -87,13 +87,23 @@ class InputAccordion(gr.Checkbox):
         self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
         InputAccordion.global_index += 1

-        kwargs['elem_id'] = self.accordion_id + "-checkbox"
-        kwargs['visible'] = False
-        super().__init__(value, **kwargs)
+        kwargs_checkbox = {
+            **kwargs,
+            "elem_id": f"{self.accordion_id}-checkbox",
+            "visible": False,
+        }
+        super().__init__(value, **kwargs_checkbox)

         self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])

-        self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion'])
+        kwargs_accordion = {
+            **kwargs,
+            "elem_id": self.accordion_id,
+            "label": kwargs.get('label', 'Accordion'),
+            "elem_classes": ['input-accordion'],
+            "open": value,
+        }
+        self.accordion = gr.Accordion(**kwargs_accordion)

     def extra(self):
         """Allows you to put something into the label of the accordion.
...
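Both replacement dict literals splat the caller's `kwargs` first and list the widget-specific keys afterwards; later keys win in a dict literal, so the per-widget overrides (`elem_id`, `visible`, `open`, ...) take precedence, and, unlike the old in-place assignments, the caller's dict is no longer mutated. A two-line demonstration of that ordering rule:

```python
kwargs = {"label": "Hires. fix", "elem_id": "caller-id", "visible": True}
kwargs_checkbox = {**kwargs, "elem_id": "input-accordion-0-checkbox", "visible": False}

print(kwargs_checkbox["elem_id"], kwargs_checkbox["visible"])
# input-accordion-0-checkbox False
print(kwargs["elem_id"])  # caller-id -- the original dict is left untouched
```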
@@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
         return {
             "name": checkpoint.name_for_extra,
             "filename": checkpoint.filename,
+            "shorthash": checkpoint.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
             "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
...
@@ -2,6 +2,7 @@ import os

 from modules import shared, ui_extra_networks
 from modules.ui_extra_networks import quote_js
+from modules.hashes import sha256_from_cache


 class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
+        sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
+        shorthash = sha256[0:10] if sha256 else None

         return {
             "name": name,
             "filename": full_path,
+            "shorthash": shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(path),
+            "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
             "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
             "local_preview": f"{path}.preview.{shared.opts.samples_format}",
             "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
...
@@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
         return {
             "name": name,
             "filename": embedding.filename,
+            "shorthash": embedding.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(embedding.filename),
+            "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
             "prompt": quote_js(embedding.name),
             "local_preview": f"{path}.preview.{shared.opts.samples_format}",
             "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
...
@@ -93,11 +93,13 @@ class UserMetadataEditor:
         item = self.page.items.get(name, {})

         try:
             filename = item["filename"]
+            shorthash = item.get("shorthash", None)

             stats = os.stat(filename)
             params = [
                 ('Filename: ', os.path.basename(filename)),
                 ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+                ('Hash: ', shorthash),
                 ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
             ]
@@ -115,7 +117,7 @@ class UserMetadataEditor:
             errors.display(e, f"reading metadata info for {name}")
             params = []

-        table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
+        table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'

         return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
...
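Combined with the new `('Hash: ', shorthash)` row, the `if value is not None` filter means cards without a cached shorthash simply omit that table row rather than rendering a literal "None". A standalone sketch of the rendering step:

```python
params = [
    ('Filename: ', 'model.safetensors'),
    ('Hash: ', None),  # no cached shorthash for this item
    ('Modified: ', '2023-08-14 12:00'),
]

table = '<table class="file-metadata">' + "".join(
    f"<tr><th>{name}</th><td>{value}</td></tr>"
    for name, value in params
    if value is not None
) + '</table>'

assert 'Hash' not in table  # the row disappears instead of showing None
```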
@@ -48,13 +48,13 @@ class UiLoadsave:
             elif condition and not condition(saved_value):
                 pass
             else:
-                if isinstance(x, gr.Textbox) and field == 'value':  # due to an undersirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
+                if isinstance(x, gr.Textbox) and field == 'value':  # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
                     saved_value = str(saved_value)
                 elif isinstance(x, gr.Number) and field == 'value':
                     try:
                         saved_value = float(saved_value)
                     except ValueError:
-                        saved_value = -1
+                        return

                 setattr(obj, field, saved_value)
                 if init_field is not None:
...
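Returning early on ValueError means an unparseable saved value now leaves the gr.Number component at its default instead of overwriting it with -1. A simplified behavioral sketch (function name and shape are illustrative, not the actual UiLoadsave API):

```python
def restore_number(default: float, saved_value: str) -> float:
    try:
        return float(saved_value)
    except ValueError:
        return default  # old behavior returned -1 here, clobbering the default

assert restore_number(7.5, "3") == 3.0
assert restore_number(7.5, "not-a-number") == 7.5
```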
This diff is collapsed.
@@ -166,16 +166,6 @@ a{
     color: var(--button-secondary-text-color-hover);
 }

-.checkboxes-row{
-    margin-bottom: 0.5em;
-    margin-left: 0em;
-}
-
-.checkboxes-row > div{
-    flex: 0;
-    white-space: nowrap;
-    min-width: auto !important;
-}
-
 button.custom-button{
     border-radius: var(--button-large-radius);
     padding: var(--button-large-padding);
@@ -192,7 +182,7 @@ button.custom-button{
     text-align: center;
 }

-div.gradio-accordion {
+div.block.gradio-accordion {
     border: 1px solid var(--block-border-color) !important;
     border-radius: 8px !important;
     margin: 2px 0;
@@ -239,10 +229,14 @@ div.gradio-accordion {
 }

 [id$=_subseed_show] label{
-    margin-bottom: 0.5em;
+    margin-bottom: 0.65em;
     align-self: end;
 }

+[id$=_seed_extras] > div{
+    gap: 0.5em;
+}
+
 .html-log .comments{
     padding-top: 0.5em;
 }
@@ -352,7 +346,7 @@ div.gradio-accordion {
 }

 div.dimensions-tools{
-    min-width: 0 !important;
+    min-width: 1.6em !important;
     max-width: fit-content;
     flex-direction: column;
     place-content: center;
@@ -369,8 +363,8 @@ div#extras_scale_to_tab div.form{
     z-index: 5;
 }

-.image-buttons button{
-    min-width: auto;
+.image-buttons > .form{
+    justify-content: center;
 }

 .infotext {
@@ -391,19 +385,21 @@ div#extras_scale_to_tab div.form{
 /* settings */
 #quicksettings {
+    width: fit-content;
     align-items: end;
 }

 #quicksettings > div, #quicksettings > fieldset{
-    max-width: 24em;
-    min-width: 24em;
-    width: 24em;
+    max-width: 36em;
+    width: fit-content;
+    flex: 0 1 fit-content;
     padding: 0;
     border: none;
     box-shadow: none;
     background: none;
 }

+#quicksettings > div.gradio-dropdown{
+    min-width: 24em !important;
+}
+
 #settings{
     display: block;
@@ -1012,10 +1008,29 @@ div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-chi
 }

 div.block.input-accordion{
-    margin-bottom: 0.4em;
 }

 .input-accordion-extra{
     flex: 0 0 auto !important;
     margin: 0 0.5em 0 auto;
 }

+div.accordions > div.input-accordion{
+    min-width: fit-content !important;
+}
+
+div.accordions > div.gradio-accordion .label-wrap span{
+    white-space: nowrap;
+    margin-right: 0.25em;
+}
+
+div.accordions{
+    gap: 0.5em;
+}
+
+div.accordions > div.input-accordion.input-accordion-open{
+    flex: 1 auto;
+    flex-flow: column;
+}
@@ -12,8 +12,6 @@ fi

 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
-export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
-export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
 export PYTORCH_ENABLE_MPS_FALLBACK=1

 ####################################################################