Commit 4147fd6b authored by AUTOMATIC1111, committed by GitHub

Merge branch 'dev' into 10141-gradio-user-exif

parents f603275d bedcd2f3
@@ -50,13 +50,14 @@ module.exports = {
     globals: {
         //script.js
         gradioApp: "readonly",
+        executeCallbacks: "readonly",
+        onAfterUiUpdate: "readonly",
+        onOptionsChanged: "readonly",
         onUiLoaded: "readonly",
         onUiUpdate: "readonly",
-        onOptionsChanged: "readonly",
         uiCurrentTab: "writable",
-        uiElementIsVisible: "readonly",
         uiElementInSight: "readonly",
-        executeCallbacks: "readonly",
+        uiElementIsVisible: "readonly",
         //ui.js
         opts: "writable",
         all_gallery_buttons: "readonly",
@@ -84,5 +85,7 @@ module.exports = {
         // imageviewer.js
         modalPrevImage: "readonly",
         modalNextImage: "readonly",
+        // token-counters.js
+        setupTokenCounters: "readonly",
     }
 };
@@ -43,8 +43,8 @@ body:
   - type: input
     id: commit
     attributes:
-      label: Commit where the problem happens
-      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
+      label: Version or Commit where the problem happens
+      description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
     validations:
       required: true
   - type: dropdown
@@ -80,6 +80,23 @@ body:
         - AMD GPUs (RX 5000 below)
         - CPU
         - Other GPUs
+  - type: dropdown
+    id: cross_attention_opt
+    attributes:
+      label: Cross attention optimization
+      description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization
+      multiple: false
+      options:
+        - Automatic
+        - xformers
+        - sdp-no-mem
+        - sdp
+        - Doggettx
+        - V1
+        - InvokeAI
+        - "None "
+    validations:
+      required: true
   - type: dropdown
     id: browsers
     attributes:
@@ -50,7 +50,7 @@ jobs:
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
-        run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
+        run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
       - name: Show coverage
         run: |
           python -m coverage combine .coverage*
## 1.4.0
### Features:
* zoom controls for inpainting
* run basic torch calculation at startup in parallel to reduce the performance impact of first generation
* option to pad prompt/neg prompt to be same length
* remove taming_transformers dependency
* custom k-diffusion scheduler settings
* add an option to show selected settings in main txt2img/img2img UI
* sysinfo tab in settings
* infer styles from prompts when pasting params into the UI
* an option to control the behavior of the above
### Minor:
* bump Gradio to 3.32.0
* bump xformers to 0.0.20
* Add option to disable token counters
* tooltip fixes & optimizations
* make it possible to configure filename for the zip download
* `[vae_filename]` pattern for filenames
* Revert discarding penultimate sigma for DPM-Solver++(2M) SDE
* change UI reorder setting to multiselect
* read version info from CHANGELOG.md if git version info is not available
* link footer API to Wiki when API is not active
* persistent conds cache (opt-in optimization)
### Extensions:
* After installing extensions, webui properly restarts the process rather than reloads the UI
* Added VAE listing to web API. Via: /sdapi/v1/sd-vae
* custom unet support
* Add onAfterUiUpdate callback
* refactor EmbeddingDatabase.register_embedding() to allow unregistering
* add before_process callback for scripts
* add ability for alwayson scripts to specify section and let user reorder those sections
### Bug Fixes:
* Fix dragging text to prompt
* fix incorrect quoting for infotext values with colon in them
* fix "hires. fix" prompt sharing same labels with txt2img_prompt
* Fix s_min_uncond default type int
* Fix for #10643 (Inpainting mask sometimes not working)
* fix bad styling for thumbs view in extra networks #10639
* fix for empty list of optimizations #10605
* small fixes to prepare_tcmalloc for Debian/Ubuntu compatibility
* fix --ui-debug-mode exit
* patch GitPython to not use leaky persistent processes
* fix duplicate Cross attention optimization after UI reload
* torch.cuda.is_available() check for SdOptimizationXformers
* fix hires fix using wrong conds in second pass if using Loras.
* handle exception when parsing generation parameters from png info
* fix upcast attention dtype error
* forcing Torch Version to 1.13.1 for RX 5000 series GPUs
* split mask blur into X and Y components, patch Outpainting MK2 accordingly
* don't die when a LoRA is a broken symlink
* allow activation of Generate Forever during generation
## 1.3.2

### Bug Fixes:
 import os
-import sys
-import traceback

 from basicsr.utils.download_util import load_file_from_url

 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
-from modules import shared, script_callbacks
+from modules import shared, script_callbacks, errors
 import sd_hijack_autoencoder  # noqa: F401
 import sd_hijack_ddpm_v1  # noqa: F401

@@ -51,10 +49,8 @@ class UpscalerLDSR(Upscaler):
         try:
             return LDSR(model, yaml)
         except Exception:
-            print("Error importing LDSR:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+            errors.report("Error importing LDSR", exc_info=True)
             return None

     def do_upscale(self, img, path):
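Several files in this merge replace hand-rolled `print` + `traceback` error logging with a shared `errors.report` helper. The helper itself is not shown in this diff; a minimal sketch of what it is assumed to do, inferred from the call sites like `errors.report("Error importing LDSR", exc_info=True)`:

```python
import sys
import traceback

def report(message: str, *, exc_info: bool = False):
    # Assumed behavior: print the message to stderr and, when exc_info is set,
    # append the traceback of the exception currently being handled.
    print(message, file=sys.stderr)
    if exc_info:
        print(traceback.format_exc(), file=sys.stderr)

try:
    raise ImportError("demo failure")
except Exception:
    report("Error importing LDSR", exc_info=True)
```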
@@ -10,7 +10,7 @@ from contextlib import contextmanager
 from torch.optim.lr_scheduler import LambdaLR

 from ldm.modules.ema import LitEma
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.util import instantiate_from_config

@@ -91,8 +91,9 @@ class VQModel(pl.LightningModule):
             del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
-        if len(missing) > 0:
+        if missing:
             print(f"Missing Keys: {missing}")
+        if unexpected:
             print(f"Unexpected Keys: {unexpected}")

     def on_train_batch_end(self, *args, **kwargs):
@@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule):
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
-        if len(missing) > 0:
+        if missing:
             print(f"Missing Keys: {missing}")
-        if len(unexpected) > 0:
+        if unexpected:
             print(f"Unexpected Keys: {unexpected}")

     def q_mean_variance(self, x_start, t):
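The recurring `len(x) > 0` → `if x:` rewrites throughout this merge rely on standard Python truthiness: an empty sequence is falsy, so both spellings are equivalent.

```python
missing, unexpected = [], ["extra.weight"]
assert not missing   # empty list is falsy, so `if missing:` skips the print
assert unexpected    # non-empty list is truthy, same as `len(unexpected) > 0`
```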
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
# where the license is as follows:
#
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits is False, "Only for interface compatible with Gumbel"
assert return_logits is False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
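Since the quantizer is now vendored, it can be exercised standalone. A quick usage sketch; the codebook size, embedding dimension, and beta below are illustrative values (typical VQ-VAE settings, not taken from this diff), and the import path assumes the vendored file is on `sys.path`:

```python
import torch
from vqvae_quantize import VectorQuantizer2

# Illustrative config: 512 codebook entries, embedding dim 4, commitment weight 0.25.
quantizer = VectorQuantizer2(n_e=512, e_dim=4, beta=0.25)

z = torch.randn(2, 4, 8, 8)  # (batch, channel, height, width)
z_q, loss, (perplexity, min_encodings, indices) = quantizer(z)

assert z_q.shape == z.shape        # quantized latents keep the input shape
assert indices.shape == (2 * 8 * 8,)  # one codebook index per spatial position
```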
@@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora

-        if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

         names = []
         multipliers = []
         for params in params_list:
-            assert len(params.items) > 0
+            assert params.items

             names.append(params.items[0])
             multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
@@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk):
         else:
             raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")

-    if len(keys_failed_to_match) > 0:
+    if keys_failed_to_match:
         print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")

     return lora

@@ -267,7 +267,7 @@ def load_loras(names, multipliers=None):
             lora.multiplier = multipliers[i] if multipliers else 1.0
             loaded_loras.append(lora)

-    if len(failed_to_load_loras) > 0:
+    if failed_to_load_loras:
         sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))

@@ -448,7 +448,11 @@ def list_available_loras():
             continue

         name = os.path.splitext(os.path.basename(filename))[0]
-        entry = LoraOnDisk(name, filename)
+        try:
+            entry = LoraOnDisk(name, filename)
+        except OSError:  # should catch FileNotFoundError and PermissionError etc.
+            errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
+            continue

         available_loras[name] = entry
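The single `except OSError` above covers the failure modes named in the changelog ("don't die when a LoRA is a broken symlink") because the relevant filesystem exceptions are OSError subclasses:

```python
# FileNotFoundError (e.g. a broken symlink target) and PermissionError are
# both subclasses of OSError, so one handler catches them all.
assert issubclass(FileNotFoundError, OSError)
assert issubclass(PermissionError, OSError)
```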
@@ -13,7 +13,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         lora.list_available_loras()

     def list_items(self):
-        for name, lora_on_disk in lora.available_loras.items():
+        for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
             path, ext = os.path.splitext(lora_on_disk.filename)

             alias = lora_on_disk.get_alias()

@@ -27,6 +27,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
                 "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                 "local_preview": f"{path}.{shared.opts.samples_format}",
                 "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
+                "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
             }

     def allowed_directories_for_previews(self):
 import os.path
 import sys
-import traceback

 import PIL.Image
 import numpy as np

@@ -10,8 +9,9 @@ from tqdm import tqdm
 from basicsr.utils.download_util import load_file_from_url

 import modules.upscaler
-from modules import devices, modelloader, script_callbacks
+from modules import devices, modelloader, script_callbacks, errors
 from scunet_model_arch import SCUNet as net
 from modules.shared import opts

@@ -38,8 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
                 scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                 scalers.append(scaler_data)
             except Exception:
-                print(f"Error loading ScuNET model: {file}", file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
+                errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
         if add_model2:
             scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
             scalers.append(scaler_data2)
This diff is collapsed.
import gradio as gr
from modules import shared
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
.canvas-tooltip-info {
position: absolute;
top: 10px;
left: 10px;
cursor: help;
background-color: rgba(0, 0, 0, 0.3);
width: 20px;
height: 20px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
flex-direction: column;
z-index: 100;
}
.canvas-tooltip-info::after {
content: '';
display: block;
width: 2px;
height: 7px;
background-color: white;
margin-top: 2px;
}
.canvas-tooltip-info::before {
content: '';
display: block;
width: 2px;
height: 2px;
background-color: white;
}
.canvas-tooltip-content {
display: none;
background-color: #f9f9f9;
color: #333;
border: 1px solid #ddd;
padding: 15px;
position: absolute;
top: 40px;
left: 10px;
width: 250px;
font-size: 16px;
opacity: 0;
border-radius: 8px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
z-index: 100;
}
.canvas-tooltip:hover .canvas-tooltip-content {
display: block;
animation: fadeIn 0.5s;
opacity: 1;
}
@keyframes fadeIn {
from {opacity: 0;}
to {opacity: 1;}
}
import gradio as gr
from modules import scripts, shared, ui_components, ui_settings
from modules.ui_components import FormColumn
class ExtraOptionsSection(scripts.Script):
section = "extra_options"
def __init__(self):
self.comps = None
self.setting_names = None
def title(self):
return "Extra options"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.comps = []
self.setting_names = []
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
self.comps.append(comp)
self.setting_names.append(setting_name)
def get_settings_values():
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
return self.comps
def before_process(self, p, *args):
for name, value in zip(self.setting_names, args):
if name not in p.override_settings:
p.override_settings[name] = value
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
}))
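The `ExtraOptionsSection` script above uses the `before_process` callback that this release adds for scripts (see the Extensions section of the changelog). A minimal sketch of an `AlwaysVisible` script using the same hook; the class name and the overridden setting are hypothetical, chosen only for illustration:

```python
from modules import scripts

class ExampleBeforeProcess(scripts.Script):
    """Hypothetical minimal script demonstrating the new before_process hook."""

    def title(self):
        return "Example before_process"

    def show(self, is_img2img):
        return scripts.AlwaysVisible  # run on every generation, like ExtraOptionsSection

    def before_process(self, p, *args):
        # Runs before processing begins; p is the StableDiffusionProcessing object.
        # Mirroring ExtraOptionsSection, only apply the setting if not already overridden.
        p.override_settings.setdefault("CLIP_stop_at_last_layers", 2)
```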
-<div class='card' style={style} onclick={card_clicked}>
+<div class='card' style={style} onclick={card_clicked} {sort_keys}>
     {background_image}
     {metadata_button}
     <div class='actions'>
 <div>
-        <a href="/docs">API</a>
+        <a href="{api_docs}">API</a>
          • 
         <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
          • 
         <a href="https://gradio.app">Gradio</a>
          • 
+        <a href="#" onclick="showProfile('./internal/profile-startup'); return false;">Startup profile</a>
+         • 
         <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
 </div>
 <br />
@@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) {
     }
 }

-onUiUpdate(function() {
+onAfterUiUpdate(function() {
     var arPreviewRect = gradioApp().querySelector('#imageARPreview');
     if (arPreviewRect) {
         arPreviewRect.style.display = 'none';
@@ -148,12 +148,18 @@ var addContextMenuEventListener = initResponse[2];
             500);
     };

-    appendContextMenuOption('#txt2img_generate', 'Generate forever', function() {
+    let generateOnRepeat_txt2img = function() {
         generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
-    });
-    appendContextMenuOption('#img2img_generate', 'Generate forever', function() {
+    };
+
+    let generateOnRepeat_img2img = function() {
         generateOnRepeat('#img2img_generate', '#img2img_interrupt');
-    });
+    };
+
+    appendContextMenuOption('#txt2img_generate', 'Generate forever', generateOnRepeat_txt2img);
+    appendContextMenuOption('#txt2img_interrupt', 'Generate forever', generateOnRepeat_txt2img);
+    appendContextMenuOption('#img2img_generate', 'Generate forever', generateOnRepeat_img2img);
+    appendContextMenuOption('#img2img_interrupt', 'Generate forever', generateOnRepeat_img2img);

     let cancelGenerateForever = function() {
         clearInterval(window.generateOnRepeatInterval);

@@ -167,6 +173,4 @@ var addContextMenuEventListener = initResponse[2];
 })();
 //End example Context Menu Items

-onUiUpdate(function() {
-    addContextMenuEventListener();
-});
+onAfterUiUpdate(addContextMenuEventListener);
@@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) {
     }
 }

+function eventHasFiles(e) {
+    if (!e.dataTransfer || !e.dataTransfer.files) return false;
+    if (e.dataTransfer.files.length > 0) return true;
+    if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true;
+
+    return false;
+}
+
+function dragDropTargetIsPrompt(target) {
+    if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true;
+    if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true;
+    return false;
+}
+
 window.document.addEventListener('dragover', e => {
     const target = e.composedPath()[0];
-    const imgWrap = target.closest('[data-testid="image"]');
-    if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
-        return;
-    }
+    if (!eventHasFiles(e)) return;
+
+    var targetImage = target.closest('[data-testid="image"]');
+    if (!dragDropTargetIsPrompt(target) && !targetImage) return;
     e.stopPropagation();
     e.preventDefault();
     e.dataTransfer.dropEffect = 'copy';

@@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => {
 window.document.addEventListener('drop', e => {
     const target = e.composedPath()[0];
-    if (target.placeholder.indexOf("Prompt") == -1) {
-        return;
+    if (!eventHasFiles(e)) return;
+
+    if (dragDropTargetIsPrompt(target)) {
+        e.stopPropagation();
+        e.preventDefault();
+
+        let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
+
+        const imgParent = gradioApp().getElementById(prompt_target);
+        const files = e.dataTransfer.files;
+        const fileInput = imgParent.querySelector('input[type="file"]');
+        if (fileInput) {
+            fileInput.files = files;
+            fileInput.dispatchEvent(new Event('change'));
+        }
     }
-    const imgWrap = target.closest('[data-testid="image"]');
-    if (!imgWrap) {
+
+    var targetImage = target.closest('[data-testid="image"]');
+    if (targetImage) {
+        e.stopPropagation();
+        e.preventDefault();
+        const files = e.dataTransfer.files;
+        dropReplaceImage(targetImage, files);
         return;
     }
-    e.stopPropagation();
-    e.preventDefault();
-    const files = e.dataTransfer.files;
-    dropReplaceImage(imgWrap, files);
 });

 window.addEventListener('paste', e => {
@@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type) {
     }
     return [confirmed, config_state_name, config_restore_type];
 }
+
+function toggle_all_extensions(event) {
+    gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
+        checkbox_el.checked = event.target.checked;
+    });
+}
+
+function toggle_extension() {
+    let all_extensions_toggled = true;
+    for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
+        if (!checkbox_el.checked) {
+            all_extensions_toggled = false;
+            break;
+        }
+    }
+
+    gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
+}
@@ -3,10 +3,17 @@ function setupExtraNetworksForTab(tabname) {
     var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
     var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
+    var sort = gradioApp().getElementById(tabname + '_extra_sort');
+    var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
     var refresh = gradioApp().getElementById(tabname + '_extra_refresh');

     search.classList.add('search');
+    sort.classList.add('sort');
+    sortOrder.classList.add('sortorder');
+    sort.dataset.sortkey = 'sortDefault';
     tabs.appendChild(search);
+    tabs.appendChild(sort);
+    tabs.appendChild(sortOrder);
     tabs.appendChild(refresh);

     var applyFilter = function() {

@@ -26,8 +33,51 @@ function setupExtraNetworksForTab(tabname) {
         });
     };

+    var applySort = function() {
+        var reverse = sortOrder.classList.contains("sortReverse");
+        var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
+        sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
+        var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
+        if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
+            return;
+        }
+
+        sort.dataset.sortkey = sortKeyStore;
+
+        var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
+        cards.forEach(function(card) {
+            card.originalParentElement = card.parentElement;
+        });
+        var sortedCards = Array.from(cards);
+        sortedCards.sort(function(cardA, cardB) {
+            var a = cardA.dataset[sortKey];
+            var b = cardB.dataset[sortKey];
+            if (!isNaN(a) && !isNaN(b)) {
+                return parseInt(a) - parseInt(b);
+            }
+
+            return (a < b ? -1 : (a > b ? 1 : 0));
+        });
+        if (reverse) {
+            sortedCards.reverse();
+        }
+        cards.forEach(function(card) {
+            card.remove();
+        });
+        sortedCards.forEach(function(card) {
+            card.originalParentElement.appendChild(card);
+        });
+    };
+
     search.addEventListener("input", applyFilter);
     applyFilter();

+    ["change", "blur", "click"].forEach(function(evt) {
+        sort.querySelector("input").addEventListener(evt, applySort);
+    });
+    sortOrder.addEventListener("click", function() {
+        sortOrder.classList.toggle("sortReverse");
+        applySort();
+    });
+
     extraNetworksApplyFilter[tabname] = applyFilter;
 }
 // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
 let txt2img_gallery, img2img_gallery, modal = undefined;
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
     if (!txt2img_gallery) {
         txt2img_gallery = attachGalleryListeners("txt2img");
     }
@@ -15,7 +15,7 @@ var titles = {
     "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
     "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
     "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
-    "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
+    "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomized",
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
     "\u{1f4be}": "Save style",

@@ -112,21 +112,29 @@ var titles = {
     "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
     "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
     "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
-    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
+    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
     "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
 };

-function updateTooltipForSpan(span) {
-    if (span.title) return; // already has a title
+function updateTooltip(element) {
+    if (element.title) return; // already has a title

-    let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+    let text = element.textContent;
+    let tooltip = localization[titles[text]] || titles[text];

     if (!tooltip) {
-        tooltip = localization[titles[span.value]] || titles[span.value];
+        let value = element.value;
+        if (value) tooltip = localization[titles[value]] || titles[value];
     }

     if (!tooltip) {
-        for (const c of span.classList) {
+        // Gradio dropdown options have `data-value`.
+        let dataValue = element.dataset.value;
+        if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
+    }
+
+    if (!tooltip) {
+        for (const c of element.classList) {
             if (c in titles) {
                 tooltip = localization[titles[c]] || titles[c];
                 break;

@@ -135,34 +143,53 @@ function updateTooltipForSpan(span) {
     }

     if (tooltip) {
-        span.title = tooltip;
+        element.title = tooltip;
     }
 }

-function updateTooltipForSelect(select) {
-    if (select.onchange != null) return;
+// Nodes to check for adding tooltips.
+const tooltipCheckNodes = new Set();
+// Timer for debouncing tooltip check.
+let tooltipCheckTimer = null;

-    select.onchange = function() {
-        select.title = localization[titles[select.value]] || titles[select.value] || "";
-    };
+function processTooltipCheckNodes() {
+    for (const node of tooltipCheckNodes) {
+        updateTooltip(node);
+    }
+    tooltipCheckNodes.clear();
 }

-var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
-
-onUiUpdate(function(m) {
-    m.forEach(function(record) {
-        record.addedNodes.forEach(function(node) {
-            if (observedTooltipElements[node.tagName]) {
-                updateTooltipForSpan(node);
-            }
-            if (node.tagName == "SELECT") {
-                updateTooltipForSelect(node);
+onUiUpdate(function(mutationRecords) {
+    for (const record of mutationRecords) {
+        if (record.type === "childList" && record.target.classList.contains("options")) {
+            // This smells like a Gradio dropdown menu having changed,
+            // so let's enqueue an update for the input element that shows the current value.
+            let wrap = record.target.parentNode;
+            let input = wrap?.querySelector("input");
+            if (input) {
+                input.title = ""; // So we'll even have a chance to update it.
+                tooltipCheckNodes.add(input);
             }
+        }
+        for (const node of record.addedNodes) {
+            if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
+                if (!node.title) {
+                    if (
+                        node.tagName === "SPAN" ||
+                        node.tagName === "BUTTON" ||
+                        node.tagName === "P" ||
+                        node.tagName === "INPUT" ||
+                        (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
+                    ) {
+                        tooltipCheckNodes.add(node);
+                    }
+                }
+                node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
+            }
+        }
+    }
+    if (tooltipCheckNodes.size) {
+        clearTimeout(tooltipCheckTimer);
+        tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
+    }
 });
@@ -39,5 +39,5 @@ function imageMaskResize() {
     });
 }

-onUiUpdate(imageMaskResize);
+onAfterUiUpdate(imageMaskResize);
 window.addEventListener('resize', imageMaskResize);
-window.onload = (function() {
-    window.addEventListener('drop', e => {
-        const target = e.composedPath()[0];
-        if (target.placeholder.indexOf("Prompt") == -1) return;
-
-        let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
-
-        e.stopPropagation();
-        e.preventDefault();
-        const imgParent = gradioApp().getElementById(prompt_target);
-        const files = e.dataTransfer.files;
-        const fileInput = imgParent.querySelector('input[type="file"]');
-        if (fileInput) {
-            fileInput.files = files;
-            fileInput.dispatchEvent(new Event('change'));
-        }
-    });
-});
@@ -170,7 +170,7 @@ function modalTileImageToggle(event) {
     event.stopPropagation();
 }

-onUiUpdate(function() {
+onAfterUiUpdate(function() {
     var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
     if (fullImg_preview != null) {
         fullImg_preview.forEach(setupImageForLightbox);
+let gamepads = [];
+
 window.addEventListener('gamepadconnected', (e) => {
     const index = e.gamepad.index;
     let isWaiting = false;
-    setInterval(async() => {
+    gamepads[index] = setInterval(async() => {
         if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
         const gamepad = navigator.getGamepads()[index];
         const xValue = gamepad.axes[0];

@@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
     }, 10);
 });

+window.addEventListener('gamepaddisconnected', (e) => {
+    clearInterval(gamepads[e.gamepad.index]);
+});
+
 /*
 Primarily for vr controller type pointer devices.
 I use the wheel event because there's currently no way to do it properly with web xr.
@@ -4,7 +4,7 @@ let lastHeadImg = null;

 let notificationButton = null;

-onUiUpdate(function() {
+onAfterUiUpdate(function() {
     if (notificationButton == null) {
         notificationButton = gradioApp().getElementById('request_notifications');
function createRow(table, cellName, items) {
var tr = document.createElement('tr');
var res = [];
items.forEach(function(x, i) {
if (x === undefined) {
res.push(null);
return;
}
var td = document.createElement(cellName);
td.textContent = x;
tr.appendChild(td);
res.push(td);
var colspan = 1;
for (var n = i + 1; n < items.length; n++) {
if (items[n] !== undefined) {
break;
}
colspan += 1;
}
if (colspan > 1) {
td.colSpan = colspan;
}
});
table.appendChild(tr);
return res;
}
function showProfile(path, cutoff = 0.05) {
requestGet(path, {}, function(data) {
var table = document.createElement('table');
table.className = 'popup-table';
data.records['total'] = data.total;
var keys = Object.keys(data.records).sort(function(a, b) {
return data.records[b] - data.records[a];
});
var items = keys.map(function(x) {
return {key: x, parts: x.split('/'), time: data.records[x]};
});
var maxLength = items.reduce(function(a, b) {
return Math.max(a, b.parts.length);
}, 0);
var cols = createRow(table, 'th', ['record', 'seconds']);
cols[0].colSpan = maxLength;
function arraysEqual(a, b) {
return !(a < b || b < a);
}
var addLevel = function(level, parent, hide) {
var matching = items.filter(function(x) {
return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
});
var sorted = matching.sort(function(a, b) {
return b.time - a.time;
});
var othersTime = 0;
var othersList = [];
var othersRows = [];
var childrenRows = [];
sorted.forEach(function(x) {
var visible = x.time >= cutoff && !hide;
var cells = [];
for (var i = 0; i < maxLength; i++) {
cells.push(x.parts[i]);
}
cells.push(x.time.toFixed(3));
var cols = createRow(table, 'td', cells);
for (i = 0; i < level; i++) {
cols[i].className = 'muted';
}
var tr = cols[0].parentNode;
if (!visible) {
tr.classList.add("hidden");
}
if (x.time >= cutoff) {
childrenRows.push(tr);
} else {
othersTime += x.time;
othersList.push(x.parts[level]);
othersRows.push(tr);
}
var children = addLevel(level + 1, parent.concat([x.parts[level]]), true);
if (children.length > 0) {
var cell = cols[level];
var onclick = function() {
cell.classList.remove("link");
cell.removeEventListener("click", onclick);
children.forEach(function(x) {
x.classList.remove("hidden");
});
};
cell.classList.add("link");
cell.addEventListener("click", onclick);
}
});
if (othersTime > 0) {
var cells = [];
for (var i = 0; i < maxLength; i++) {
cells.push(parent[i]);
}
cells.push(othersTime.toFixed(3));
cells[level] = 'others';
var cols = createRow(table, 'td', cells);
for (i = 0; i < level; i++) {
cols[i].className = 'muted';
}
var cell = cols[level];
var tr = cell.parentNode;
var onclick = function() {
tr.classList.add("hidden");
cell.classList.remove("link");
cell.removeEventListener("click", onclick);
othersRows.forEach(function(x) {
x.classList.remove("hidden");
});
};
cell.title = othersList.join(", ");
cell.classList.add("link");
cell.addEventListener("click", onclick);
if (hide) {
tr.classList.add("hidden");
}
childrenRows.push(tr);
}
return childrenRows;
};
addLevel(0, []);
popup(table);
});
}
let promptTokenCountDebounceTime = 800;
let promptTokenCountTimeouts = {};
var promptTokenCountUpdateFunctions = {};
function update_txt2img_tokens(...args) {
// Called from Gradio
update_token_counter("txt2img_token_button");
if (args.length == 2) {
return args[0];
}
return args;
}
function update_img2img_tokens(...args) {
// Called from Gradio
update_token_counter("img2img_token_button");
if (args.length == 2) {
return args[0];
}
return args;
}
function update_token_counter(button_id) {
if (opts.disable_token_counters) {
return;
}
if (promptTokenCountTimeouts[button_id]) {
clearTimeout(promptTokenCountTimeouts[button_id]);
}
promptTokenCountTimeouts[button_id] = setTimeout(
() => gradioApp().getElementById(button_id)?.click(),
promptTokenCountDebounceTime,
);
}
function recalculatePromptTokens(name) {
promptTokenCountUpdateFunctions[name]?.();
}
function recalculate_prompts_txt2img() {
// Called from Gradio
recalculatePromptTokens('txt2img_prompt');
recalculatePromptTokens('txt2img_neg_prompt');
return Array.from(arguments);
}
function recalculate_prompts_img2img() {
// Called from Gradio
recalculatePromptTokens('img2img_prompt');
recalculatePromptTokens('img2img_neg_prompt');
return Array.from(arguments);
}
function setupTokenCounting(id, id_counter, id_button) {
var prompt = gradioApp().getElementById(id);
var counter = gradioApp().getElementById(id_counter);
var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
if (opts.disable_token_counters) {
counter.style.display = "none";
return;
}
if (counter.parentElement == prompt.parentElement) {
return;
}
prompt.parentElement.insertBefore(counter, prompt);
prompt.parentElement.style.position = "relative";
promptTokenCountUpdateFunctions[id] = function() {
update_token_counter(id_button);
};
textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
}
function setupTokenCounters() {
setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
}
@@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) {
     }
 }

-var promptTokecountUpdateFuncs = {};
-
-function recalculatePromptTokens(name) {
-    if (promptTokecountUpdateFuncs[name]) {
-        promptTokecountUpdateFuncs[name]();
-    }
-}
-
-function recalculate_prompts_txt2img() {
-    recalculatePromptTokens('txt2img_prompt');
-    recalculatePromptTokens('txt2img_neg_prompt');
-    return Array.from(arguments);
-}
-
-function recalculate_prompts_img2img() {
-    recalculatePromptTokens('img2img_prompt');
-    recalculatePromptTokens('img2img_neg_prompt');
-    return Array.from(arguments);
-}
-
 var opts = {};
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
     if (Object.keys(opts).length != 0) return;

     var json_elem = gradioApp().getElementById('settings_json');

@@ -302,28 +281,7 @@ onUiUpdate(function() {
     json_elem.parentElement.style.display = "none";

-    function registerTextarea(id, id_counter, id_button) {
-        var prompt = gradioApp().getElementById(id);
-        var counter = gradioApp().getElementById(id_counter);
-        var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
-        if (counter.parentElement == prompt.parentElement) {
-            return;
-        }
-
-        prompt.parentElement.insertBefore(counter, prompt);
-        prompt.parentElement.style.position = "relative";
-
-        promptTokecountUpdateFuncs[id] = function() {
-            update_token_counter(id_button);
-        };
-        textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
-    }
-
-    registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
-    registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
-    registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
-    registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
+    setupTokenCounters();

     var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
     var settings_tabs = gradioApp().querySelector('#settings div');

@@ -354,33 +312,6 @@ onOptionsChanged(function() {
 });

 let txt2img_textarea, img2img_textarea = undefined;
-let wait_time = 800;
-let token_timeouts = {};
-
-function update_txt2img_tokens(...args) {
-    update_token_counter("txt2img_token_button");
-    if (args.length == 2) {
-        return args[0];
-    }
-    return args;
-}
-
-function update_img2img_tokens(...args) {
-    update_token_counter(
-        "img2img_token_button"
-    );
-    if (args.length == 2) {
-        return args[0];
-    }
-    return args;
-}
-
-function update_token_counter(button_id) {
-    if (token_timeouts[button_id]) {
-        clearTimeout(token_timeouts[button_id]);
-    }
-    token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
-}

 function restart_reload() {
     document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
@@ -42,7 +42,7 @@ onOptionsChanged(function() {
 function settingsHintsShowQuicksettings() {
     requestGet("./internal/quicksettings-hint", {}, function(data) {
         var table = document.createElement('table');
-        table.className = 'settings-value-table';
+        table.className = 'popup-table';

         data.forEach(function(obj) {
             var tr = document.createElement('tr');
...@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder ...@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest from secrets import compare_digest
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
 from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_vae import vae_dict
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
@@ -108,7 +109,6 @@ def api_middleware(app: FastAPI):
         from rich.console import Console
         console = Console()
     except Exception:
-        import traceback
        rich_available = False

    @app.middleware("http")
@@ -139,11 +139,12 @@ def api_middleware(app: FastAPI):
            "errors": str(e),
        }
        if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
-            print(f"API error: {request.method}: {request.url} {err}")
+            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
+                print(message)
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
-                traceback.print_exc()
+                errors.report(message, exc_info=True)
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

    @app.middleware("http")
@@ -188,7 +189,9 @@ class Api:
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
@@ -206,6 +209,11 @@ class Api:
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])

+        if shared.cmd_opts.add_stop_route:
+            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+
        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []
@@ -278,7 +286,7 @@ class Api:
            script_args[0] = selectable_idx + 1

        # Now check for always on scripts
-        if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+        if request.alwayson_scripts:
            for alwayson_script_name in request.alwayson_scripts.keys():
                alwayson_script = self.get_script(alwayson_script_name, script_runner)
                if alwayson_script is None:
@@ -538,9 +546,20 @@ class Api:
            for upscaler in shared.sd_upscalers
        ]

+    def get_latent_upscale_modes(self):
+        return [
+            {
+                "name": upscale_mode,
+            }
+            for upscale_mode in [*(shared.latent_upscale_modes or {})]
+        ]
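Note: the mode names served by the new route come straight from shared.latent_upscale_modes, so the exact list depends on the running build. A minimal client-side sketch (the printed values are illustrative, not exhaustive):

import requests

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/latent-upscale-modes")
print(resp.json())  # e.g. [{"name": "Latent"}, {"name": "Latent (nearest)"}, ...]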
    def get_sd_models(self):
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

+    def get_sd_vaes(self):
+        return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+
    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -700,4 +719,16 @@ class Api:
    def launch(self, server_name, port):
        self.app.include_router(self.router)
-        uvicorn.run(self.app, host=server_name, port=port)
+        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+    def kill_webui(self):
+        restart.stop_program()
+
+    def restart_webui(self):
+        if restart.is_restartable():
+            restart.restart_program()
+        return Response(status_code=501)
+
+    def stop_webui(request):
+        shared.state.server_command = "stop"
+        return Response("Stopping.")
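Note: the three lifecycle routes are only registered when the server is launched with --add-stop-route, and stop_webui merely sets shared.state.server_command, leaving the actual shutdown to the main loop. A minimal sketch of driving them from a test script (host and port assumed to be the defaults):

import requests

base = "http://127.0.0.1:7860"
requests.post(f"{base}/sdapi/v1/server-restart")  # answers 501 if the server cannot restart itself
requests.post(f"{base}/sdapi/v1/server-stop")     # replies "Stopping." and asks the main loop to exit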
@@ -241,6 +241,9 @@ class UpscalerItem(BaseModel):
    model_url: Optional[str] = Field(title="URL")
    scale: Optional[float] = Field(title="Scale")

+class LatentUpscalerModeItem(BaseModel):
+    name: str = Field(title="Name")
+
class SDModelItem(BaseModel):
    title: str = Field(title="Title")
    model_name: str = Field(title="Model Name")
@@ -249,6 +252,10 @@ class SDModelItem(BaseModel):
    filename: str = Field(title="Filename")
    config: Optional[str] = Field(title="Config file")

+class SDVaeItem(BaseModel):
+    model_name: str = Field(title="Model Name")
+    filename: str = Field(title="Filename")
+
class HypernetworkItem(BaseModel):
    name: str = Field(title="Name")
    path: Optional[str] = Field(title="Path")
...
 from functools import wraps
 import html
-import sys
 import threading
-import traceback
 import time

-from modules import shared, progress
+from modules import shared, progress, errors

 queue_lock = threading.Lock()
@@ -25,7 +23,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):
        # if the first argument is a string that says "task(...)", it is treated as a job id
-        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+        if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
@@ -59,16 +57,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
-            # When printing out our debug argument list, do not print out more than a MB of text
-            max_debug_str_len = 131072  # (1024*1024)/8
-            print("Error completing request", file=sys.stderr)
-            argStr = f"Arguments: {args} {kwargs}"
-            print(argStr[:max_debug_str_len], file=sys.stderr)
-            if len(argStr) > max_debug_str_len:
-                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+            # When printing out our debug argument list,
+            # do not print out more than a 100 KB of text
+            max_debug_str_len = 131072
+            message = "Error completing request"
+            arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
+            if len(arg_str) > max_debug_str_len:
+                arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+            errors.report(f"{message}\n{arg_str}", exc_info=True)

            shared.state.job = ""
            shared.state.job_count = 0
@@ -111,4 +107,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
        return tuple(res)

    return f
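Note: in the new truncation logic, arg_str is already sliced to max_debug_str_len, so len(arg_str) can never exceed the cap and the "(Argument list truncated ...)" notice never fires. A sketch of the presumably intended order, measuring before clamping (names reuse the locals above):

full = f"Arguments: {args} {kwargs}"
arg_str = full[:max_debug_str_len]
if len(full) > max_debug_str_len:
    arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(full)} characters)"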
@@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
 parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
 parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
 parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
-parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
 parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
 parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
 parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
 parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api')
 import os
-import sys
-import traceback

 import cv2
 import torch

 import modules.face_restoration
 import modules.shared
-from modules import shared, devices, modelloader
+from modules import shared, devices, modelloader, errors
 from modules.paths import models_path

 # codeformer people made a choice to include modified basicsr library to their project which makes
@@ -22,9 +20,7 @@ codeformer = None

 def setup_model(dirname):
-    global model_path
-    if not os.path.exists(model_path):
-        os.makedirs(model_path)
+    os.makedirs(model_path, exist_ok=True)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
@@ -105,8 +101,8 @@ def setup_model(dirname):
                        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        torch.cuda.empty_cache()
-                    except Exception as error:
-                        print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+                    except Exception:
+                        errors.report('Failed inference for CodeFormer', exc_info=True)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
@@ -135,7 +131,6 @@ def setup_model(dirname):
        shared.face_restorers.append(codeformer)
    except Exception:
-        print("Error setting up CodeFormer:", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report("Error setting up CodeFormer", exc_info=True)

    # sys.path = stored_sys_path
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
 """
 import os
-import sys
-import traceback
 import json
 import time
 import tqdm
@@ -13,7 +11,7 @@ from datetime import datetime
 from collections import OrderedDict
 import git

-from modules import shared, extensions
+from modules import shared, extensions, errors
 from modules.paths_internal import script_path, config_states_dir
@@ -53,8 +51,7 @@ def get_webui_config():
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
-        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)

    webui_remote = None
    webui_commit_hash = None
@@ -134,8 +131,7 @@ def restore_webui_config(config):
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
-        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
        return

    try:
@@ -143,8 +139,7 @@ def restore_webui_config(config):
        webui_repo.git.reset(webui_commit_hash, hard=True)
        print(f"* Restored webui to commit {webui_commit_hash}.")
    except Exception:
-        print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report(f"Error restoring webui to commit {webui_commit_hash}")

def restore_extension_config(config):
...
 import sys
 import contextlib
+from functools import lru_cache

 import torch
 from modules import errors
@@ -154,3 +156,19 @@ def test_for_nans(x, where):
        message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)

+@lru_cache
+def first_time_calculation():
+    """
+    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
+    spends about 2.7 seconds doing that, at least with NVidia.
+    """
+
+    x = torch.zeros((1, 1)).to(device, dtype)
+    linear = torch.nn.Linear(1, 1).to(device, dtype)
+    linear(x)
+
+    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+    conv2d(x)
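Note: @lru_cache makes the warm-up idempotent, so a caller can fire it off early without coordinating with later users. A sketch of how startup code might run it in parallel with the rest of initialization, per the 1.4.0 feature entry (the Thread wrapper is illustrative):

import threading
from modules import devices

# warm up torch in the background so the first real generation
# doesn't pay the roughly 2.7 s / 700 MB initialization cost
threading.Thread(target=devices.first_time_calculation).start()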
 import sys
+import textwrap
 import traceback

+exception_records = []
+
+def record_exception():
+    _, e, tb = sys.exc_info()
+    if e is None:
+        return
+
+    if exception_records and exception_records[-1] == e:
+        return
+
+    exception_records.append((e, tb))
+
+    if len(exception_records) > 5:
+        exception_records.pop(0)
+
+def report(message: str, *, exc_info: bool = False) -> None:
+    """
+    Print an error message to stderr, with optional traceback.
+    """
+
+    record_exception()
+
+    for line in message.splitlines():
+        print("***", line, file=sys.stderr)
+    if exc_info:
+        print(textwrap.indent(traceback.format_exc(), "    "), file=sys.stderr)
+        print("---", file=sys.stderr)
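Note: a quick sketch of what the new helper emits when called from inside an except block (output shown is approximate):

from modules import errors

try:
    1 / 0
except Exception:
    errors.report("Error computing ratio", exc_info=True)

# stderr, roughly:
# *** Error computing ratio
#     Traceback (most recent call last):
#       ...
#     ZeroDivisionError: division by zero
# ---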
 def print_error_explanation(message):
+    record_exception()
+
    lines = message.strip().split("\n")
    max_len = max([len(x) for x in lines])
@@ -12,9 +46,15 @@ def print_error_explanation(message):
    print('=' * max_len, file=sys.stderr)

-def display(e: Exception, task):
+def display(e: Exception, task, *, full_traceback=False):
+    record_exception()
+
    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
-    print(traceback.format_exc(), file=sys.stderr)
+    te = traceback.TracebackException.from_exception(e)
+    if full_traceback:
+        # include frames leading up to the try-catch block
+        te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
+    print(*te.format(), sep="", file=sys.stderr)

    message = str(e)
    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
@@ -28,6 +68,8 @@ already_displayed = {}

def display_once(e: Exception, task):
+    record_exception()
+
    if task in already_displayed:
        return
...
 import os
-import sys
 import threading
-import traceback

-import git
-
-from modules import shared
+from modules import shared, errors
+from modules.gitpython_hack import Repo
 from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401

 extensions = []

-if not os.path.exists(extensions_dir):
-    os.makedirs(extensions_dir)
+os.makedirs(extensions_dir, exist_ok=True)

def active():
@@ -54,10 +50,9 @@ class Extension:
            repo = None
            try:
                if os.path.exists(os.path.join(self.path, ".git")):
-                    repo = git.Repo(self.path)
+                    repo = Repo(self.path)
            except Exception:
-                print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
+                errors.report(f"Error reading github repository info from {self.path}", exc_info=True)

            if repo is None or repo.bare:
                self.remote = None
@@ -72,8 +67,8 @@ class Extension:
                    self.commit_hash = commit.hexsha
                    self.version = self.commit_hash[:8]

-            except Exception as ex:
-                print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
+            except Exception:
+                errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
                self.remote = None

        self.have_info_from_repo = True
@@ -94,7 +89,7 @@ class Extension:
        return res

    def check_updates(self):
-        repo = git.Repo(self.path)
+        repo = Repo(self.path)
        for fetch in repo.remote().fetch(dry_run=True):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
@@ -116,7 +111,7 @@ class Extension:
        self.status = "latest"

    def fetch_and_reset_hard(self, commit='origin'):
-        repo = git.Repo(self.path)
+        repo = Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)
...
@@ -32,6 +32,9 @@ class ExtraNetworkParams:
            else:
                self.positional.append(item)

+    def __eq__(self, other):
+        return self.items == other.items
+
class ExtraNetwork:
    def __init__(self, name):
...
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
    def activate(self, p, params_list):
        additional = shared.opts.sd_hypernetwork

-        if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
            hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
            p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
        names = []
        multipliers = []
        for params in params_list:
-            assert len(params.items) > 0
+            assert params.items
            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
...
@@ -55,7 +55,7 @@ def image_from_url_text(filedata):
    if filedata is None:
        return None

-    if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+    if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if type(filedata) == dict and filedata.get("is_file", False):
@@ -265,19 +265,30 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
        else:
            prompt += ("" if prompt == "" else "\n") + line

+    if shared.opts.infotext_styles != "Ignore":
+        found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
+
+        if shared.opts.infotext_styles == "Apply":
+            res["Styles array"] = found_styles
+        elif shared.opts.infotext_styles == "Apply if any" and found_styles:
+            res["Styles array"] = found_styles
+
    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
-        if v[0] == '"' and v[-1] == '"':
-            v = unquote(v)
-
-        m = re_imagesize.match(v)
-        if m is not None:
-            res[f"{k}-1"] = m.group(1)
-            res[f"{k}-2"] = m.group(2)
-        else:
-            res[k] = v
+        try:
+            if v[0] == '"' and v[-1] == '"':
+                v = unquote(v)
+
+            m = re_imagesize.match(v)
+            if m is not None:
+                res[f"{k}-1"] = m.group(1)
+                res[f"{k}-2"] = m.group(2)
+            else:
+                res[k] = v
+        except Exception:
+            print(f"Error parsing \"{k}: {v}\"")

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
@@ -306,6 +317,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    if "RNG" not in res:
        res["RNG"] = "GPU"

+    if "Schedule type" not in res:
+        res["Schedule type"] = "Automatic"
+
+    if "Schedule max sigma" not in res:
+        res["Schedule max sigma"] = 0
+
+    if "Schedule min sigma" not in res:
+        res["Schedule min sigma"] = 0
+
+    if "Schedule rho" not in res:
+        res["Schedule rho"] = 0
+
    return res
@@ -318,6 +341,10 @@ infotext_to_setting_name_mapping = [
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
+    ('Schedule type', 'k_sched_type'),
+    ('Schedule max sigma', 'sigma_max'),
+    ('Schedule min sigma', 'sigma_min'),
+    ('Schedule rho', 'rho'),
    ('Noise multiplier', 'initial_noise_multiplier'),
    ('Eta', 'eta_ancestral'),
    ('Eta DDIM', 'eta_ddim'),
@@ -330,6 +357,7 @@ infotext_to_setting_name_mapping = [
    ('Token merging ratio hr', 'token_merging_ratio_hr'),
    ('RNG', 'randn_source'),
    ('NGMS', 's_min_uncond'),
+    ('Pad conds', 'pad_cond_uncond'),
]
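Note: with the mapping above, the scheduler settings now round-trip through infotext, and the defaults added to the parser keep override-settings comparisons stable when a field is absent. An illustrative fragment (values made up):

# parsing a line like this...
lastline = "Steps: 20, Sampler: DPM++ 2M, Schedule type: karras, Schedule rho: 7, CFG scale: 7"
# ...yields res["Schedule type"] == "karras" and res["Schedule rho"] == "7";
# absent fields such as "Schedule max sigma" fall back to the 0/"Automatic" defaults above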
@@ -421,7 +449,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
        vals_pairs = [f"{k}: {v}" for k, v in vals.items()]

-        return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+        return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

    paste_fields = paste_fields + [(override_settings_component, paste_settings)]
@@ -438,5 +466,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
        outputs=[],
        show_progress=False,
    )
 import os
-import sys
-import traceback

 import facexlib
 import gfpgan

 import modules.face_restoration
-from modules import paths, shared, devices, modelloader
+from modules import paths, shared, devices, modelloader, errors

 model_dir = "GFPGAN"
 user_path = None
@@ -72,11 +70,8 @@ gfpgan_constructor = None

 def setup_model(dirname):
-    global model_path
-    if not os.path.exists(model_path):
-        os.makedirs(model_path)
-
    try:
+        os.makedirs(model_path, exist_ok=True)
        from gfpgan import GFPGANer
        from facexlib import detection, parsing  # noqa: F401

        global user_path
@@ -112,5 +107,4 @@ def setup_model(dirname):
        shared.face_restorers.append(FaceRestorerGFPGAN())
    except Exception:
-        print("Error setting up GFPGAN:", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report("Error setting up GFPGAN", exc_info=True)
+from __future__ import annotations
+
+import io
+import subprocess
+
+import git
+
+
+class Git(git.Git):
+    """
+    Git subclassed to never use persistent processes.
+    """
+
+    def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
+        raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
+
+    def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
+        ret = subprocess.check_output(
+            [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
+            input=self._prepare_ref(ref),
+            cwd=self._working_dir,
+            timeout=2,
+        )
+        return self._parse_object_header(ret)
+
+    def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+        # Not really streaming, per se; this buffers the entire object in memory.
+        # Shouldn't be a problem for our use case, since we're only using this for
+        # object headers (commit objects).
+        ret = subprocess.check_output(
+            [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
+            input=self._prepare_ref(ref),
+            cwd=self._working_dir,
+            timeout=30,
+        )
+        bio = io.BytesIO(ret)
+        hexsha, typename, size = self._parse_object_header(bio.readline())
+        return (hexsha, typename, size, self.CatFileContentStream(size, bio))
+
+
+class Repo(git.Repo):
+    GitCommandWrapperType = Git
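Note: plain GitPython keeps a persistent `git cat-file --batch` child process alive per repository; this subclass turns each object lookup into a one-shot subprocess instead. A minimal sketch of the intended drop-in substitution (the path is illustrative):

from modules.gitpython_hack import Repo

repo = Repo("/path/to/extension")    # same API as git.Repo
print(repo.head.commit.hexsha[:8])   # object lookups run as short-lived subprocesses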
@@ -2,8 +2,6 @@ import datetime
 import glob
 import html
 import os
-import sys
-import traceback
 import inspect

 import modules.textual_inversion.dataset
@@ -11,7 +9,7 @@ import torch
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
@@ -325,17 +323,14 @@ def load_hypernetwork(name):
    if path is None:
        return None

+    hypernetwork = Hypernetwork()
+
    try:
-        hypernetwork = Hypernetwork()
        hypernetwork.load(path)
-        return hypernetwork
    except Exception:
-        print(f"Error loading hypernetwork {path}", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report(f"Error loading hypernetwork {path}", exc_info=True)
        return None

+    return hypernetwork

def load_hypernetworks(names, multipliers=None):
    already_loaded = {}
@@ -770,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report("Exception in training hypernetwork", exc_info=True)
    finally:
        pbar.leave = False
        pbar.close()
...
 import datetime
-import sys
-import traceback

 import pytz
 import io
@@ -21,6 +19,8 @@ from modules import sd_samplers, shared, script_callbacks, errors
 from modules.paths_internal import roboto_ttf_file
 from modules.shared import opts

+import modules.sd_vae as sd_vae
+
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -336,8 +336,20 @@ def sanitize_filename_part(text, replace_spaces=True):

 class FilenameGenerator:
+    def get_vae_filename(self):  # get the name of the VAE file.
+        if sd_vae.loaded_vae_file is None:
+            return "NoneType"
+        file_name = os.path.basename(sd_vae.loaded_vae_file)
+        split_file_name = file_name.split('.')
+        if len(split_file_name) > 1 and split_file_name[0] == '':
+            return split_file_name[1]  # if the first character of the filename is "." then [1] is obtained.
+        else:
+            return split_file_name[0]
+
    replacements = {
        'seed': lambda self: self.seed if self.seed is not None else '',
+        'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+        'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
        'steps': lambda self: self.p and self.p.steps,
        'cfg': lambda self: self.p and self.p.cfg_scale,
        'width': lambda self: self.image.width,
@@ -354,20 +366,23 @@ class FilenameGenerator:
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
        'prompt_words': lambda self: self.prompt_words(),
-        'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
-        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+        'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+        'batch_size': lambda self: self.p.batch_size,
+        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]
        'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
        'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
        'user': lambda self: self.p.user,
+        'vae_filename': lambda self: self.get_vae_filename(),
    }
    default_time_format = '%Y%m%d%H%M%S'

-    def __init__(self, p, seed, prompt, image):
+    def __init__(self, p, seed, prompt, image, zip=False):
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image
+        self.zip = zip
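Note: a sketch of what the new tokens do in an images-filename pattern, assuming batch_size=4 and a loaded VAE file named "vae-ft-mse-840000.safetensors" (pattern and values are illustrative):

# pattern "[seed_first]-[batch_size]-[vae_filename]" might render as
# "12345-4-vae-ft-mse-840000"; with zip=True, batch_number and
# generation_number collapse to nothing so zip archives get a stable name
pattern = "[seed_first]-[batch_size]-[vae_filename]"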
    def hasprompt(self, *args):
        lower = self.prompt.lower()
@@ -391,7 +406,7 @@ class FilenameGenerator:
        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
-            if len(style) > 0:
+            if style:
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
@@ -400,7 +415,7 @@ class FilenameGenerator:
        return sanitize_filename_part(prompt_no_style, replace_spaces=False)

    def prompt_words(self):
-        words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+        words = [x for x in re_nonletters.split(self.prompt or "") if x]
        if len(words) == 0:
            words = ["empty"]
        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
@@ -408,7 +423,7 @@ class FilenameGenerator:
    def datetime(self, *args):
        time_datetime = datetime.datetime.now()

-        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+        time_format = args[0] if (args and args[0] != "") else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError:
@@ -447,8 +462,7 @@ class FilenameGenerator:
                replacement = fun(self, *pattern_args)
            except Exception:
                replacement = None
-                print(f"Error adding [{pattern}] to filename", file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
+                errors.report(f"Error adding [{pattern}] to filename", exc_info=True)

            if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
                continue
@@ -665,9 +679,10 @@ def read_info_from_image(image):
        items['exif comment'] = exif_comment
        geninfo = exif_comment

-    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
-                  'loop', 'background', 'timestamp', 'duration']:
+    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+                  'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+                  'icc_profile', 'chromaticity']:
        items.pop(field, None)

    if items.get("Software", None) == "NovelAI":
        try:
@@ -678,8 +693,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
        except Exception:
-            print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+            errors.report("Error parsing NovelAI image generation parameters", exc_info=True)

    return geninfo, items
...
 import os
+from pathlib import Path

 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
@@ -14,7 +15,7 @@ from modules.ui import plaintext_to_html
 import modules.scripts

-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
    processing.fix_seed(p)

    images = shared.listfiles(input_dir)
@@ -22,9 +23,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
    is_inpaint_batch = False
    if inpaint_mask_dir:
        inpaint_masks = shared.listfiles(inpaint_mask_dir)
-        is_inpaint_batch = len(inpaint_masks) > 0
-        if is_inpaint_batch:
-            print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
+        is_inpaint_batch = bool(inpaint_masks)
+
+    if is_inpaint_batch:
+        print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

    print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
@@ -50,14 +52,31 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
            continue
        # Use the EXIF orientation of photos taken by smartphones.
        img = ImageOps.exif_transpose(img)
+
+        if to_scale:
+            p.width = int(img.width * scale_by)
+            p.height = int(img.height * scale_by)
+
        p.init_images = [img] * p.batch_size

+        image_path = Path(image)
        if is_inpaint_batch:
-            # try to find corresponding mask for an image using simple filename matching
-            mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
-            # if not found use first one ("same mask for all images" use-case)
-            if mask_image_path not in inpaint_masks:
-                mask_image_path = inpaint_masks[0]
+            if len(inpaint_masks) == 1:
+                mask_image_path = inpaint_masks[0]
+            else:
+                # try to find corresponding mask for an image using simple filename matching
+                mask_image_dir = Path(inpaint_mask_dir)
+                masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
+
+                if len(masks_found) == 0:
+                    print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
+                    continue
+
+                # it should contain only 1 matching mask
+                # otherwise user has many masks with the same name but different extensions
+                mask_image_path = masks_found[0]
+
            mask_image = Image.open(mask_image_path)
            p.image_mask = mask_image
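Note: the stem-based glob means an image and its mask only have to share a basename, not an extension. A sketch under an assumed layout (paths are illustrative):

from pathlib import Path

image_path = Path("photos/cat.jpg")
masks_found = list(Path("masks").glob(f"{image_path.stem}.*"))
# "masks/cat.png" matches despite the different extension; several files
# named "cat.*" would be ambiguous, and only the first match is used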
@@ -66,7 +85,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
        proc = process_images(p)

        for n, processed_image in enumerate(proc.images):
-            filename = os.path.basename(image)
+            filename = image_path.name

            if n > 0:
                left, right = os.path.splitext(filename)
@@ -93,7 +112,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
    elif mode == 2:  # inpaint
        image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
-        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+        mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
+        mask = ImageChops.lighter(alpha_mask, mask).convert('L')
        image = image.convert("RGB")
    elif mode == 3:  # inpaint sketch
        image = inpaint_color_sketch
@@ -115,7 +135,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
    if image is not None:
        image = ImageOps.exif_transpose(image)

-    if selected_scale_tab == 1:
+    if selected_scale_tab == 1 and not is_batch:
        assert image, "Can't scale by because no image is selected"

        width = int(image.width * scale_by)
@@ -172,7 +192,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
    if is_batch:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

-        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
+        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)

        processed = Processed(p, [], p.seed, "")
    else:
...
 import os
 import sys
-import traceback
 from collections import namedtuple
 from pathlib import Path
 import re
@@ -216,8 +215,7 @@ class InterrogateModels:
                    res += f", {match}"

        except Exception:
-            print("Error interrogating", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+            errors.report("Error interrogating", exc_info=True)
            res += "<error>"

        self.unload()
...
@@ -7,7 +7,7 @@ import platform
 import json
 from functools import lru_cache

-from modules import cmd_args
+from modules import cmd_args, errors
 from modules.paths_internal import script_path, extensions_dir

 args, _ = cmd_args.parser.parse_known_args()
@@ -68,7 +68,13 @@ def git_tag():
    try:
        return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
    except Exception:
-        return "<none>"
+        try:
+            from pathlib import Path
+            changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
+            with changelog_md.open(encoding="utf-8") as file:
+                return next((line.strip() for line in file if line.strip()), "<none>")
+        except Exception:
+            return "<none>"
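Note: on an install without usable git metadata, the version string now falls back to the first non-empty line of CHANGELOG.md. For a changelog that begins with a "## 1.4.0" heading, a sketch of the result:

print(git_tag())  # -> "## 1.4.0" (illustrative; still "<none>" if the file is missing too)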
 def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
@@ -141,10 +147,10 @@ def git_clone(url, dir, name, commithash=None):
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
        return

-    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)

    if commithash is not None:
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
@@ -188,7 +194,7 @@ def run_extension_installer(extension_dir):
        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
    except Exception as e:
-        print(e, file=sys.stderr)
+        errors.report(str(e))

def list_extensions(settings_file):
@@ -198,8 +204,8 @@ def list_extensions(settings_file):
        if os.path.isfile(settings_file):
            with open(settings_file, "r", encoding="utf8") as file:
                settings = json.load(file)
-    except Exception as e:
-        print(e, file=sys.stderr)
+    except Exception:
+        errors.report("Could not load settings", exc_info=True)

    disabled_extensions = set(settings.get('disabled_extensions', []))
    disable_all_extensions = settings.get('disable_all_extensions', 'none')
@@ -223,23 +229,28 @@ def prepare_environment():
    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
    clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
-    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
    k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
    blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
-    taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

+    try:
+        # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
+        os.remove(os.path.join(script_path, "tmp", "restart"))
+        os.environ.setdefault('SD_WEBUI_RESTARTING ', '1')
+    except OSError:
+        pass
+
    if not args.skip_python_version_check:
        check_python_version()
@@ -286,7 +297,6 @@ def prepare_environment():
    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
-    git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
    git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
    git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
...
import json import json
import os import os
import sys
import traceback
from modules import errors
localizations = {} localizations = {}
...@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: ...@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file: with open(fn, "r", encoding="utf8") as file:
data = json.load(file) data = json.load(file)
except Exception: except Exception:
print(f"Error loading localization from {fn}:", file=sys.stderr) errors.report(f"Error loading localization from {fn}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
return f"window.localization = {json.dumps(data)}" return f"window.localization = {json.dumps(data)}"
...@@ -15,6 +15,8 @@ def send_everything_to_cpu(): ...@@ -15,6 +15,8 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram): def setup_for_low_vram(sd_model, use_medvram):
sd_model.lowvram = True
parents = {} parents = {}
def send_me_to_gpu(module, _): def send_me_to_gpu(module, _):
...@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks: for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu) block.register_forward_pre_hook(send_me_to_gpu)
def is_enabled(sd_model):
return getattr(sd_model, 'lowvram', False)
...@@ -95,8 +95,7 @@ def cleanup_models(): ...@@ -95,8 +95,7 @@ def cleanup_models():
def move_files(src_path: str, dest_path: str, ext_filter: str = None): def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try: try:
if not os.path.exists(dest_path): os.makedirs(dest_path, exist_ok=True)
os.makedirs(dest_path)
if os.path.exists(src_path): if os.path.exists(src_path):
for file in os.listdir(src_path): for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file) fullpath = os.path.join(src_path, file)
......
...@@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): ...@@ -230,9 +230,9 @@ class DDPM(pl.LightningModule):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False) sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0: if missing:
print(f"Missing Keys: {missing}") print(f"Missing Keys: {missing}")
if len(unexpected) > 0: if unexpected:
print(f"Unexpected Keys: {unexpected}") print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t): def q_mean_variance(self, x_start, t):
......
...@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl ...@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []), (sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
......
...@@ -336,11 +336,11 @@ def parse_prompt_attention(text): ...@@ -336,11 +336,11 @@ def parse_prompt_attention(text):
round_brackets.append(len(res)) round_brackets.append(len(res))
elif text == '[': elif text == '[':
square_brackets.append(len(res)) square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0: elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight)) multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and len(round_brackets) > 0: elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier) multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and len(square_brackets) > 0: elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier) multiply_range(square_brackets.pop(), square_bracket_multiplier)
else: else:
parts = re.split(re_break, text) parts = re.split(re_break, text)
......
import os import os
import sys
import traceback
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -9,7 +7,8 @@ from realesrgan import RealESRGANer ...@@ -9,7 +7,8 @@ from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
from modules import modelloader from modules import modelloader, errors
class UpscalerRealESRGAN(Upscaler): class UpscalerRealESRGAN(Upscaler):
def __init__(self, path): def __init__(self, path):
...@@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler) self.scalers.append(scaler)
except Exception: except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr) errors.report("Error importing Real-ESRGAN", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
self.enable = False self.enable = False
self.scalers = [] self.scalers = []
...@@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
return info return info
except Exception as e: except Exception:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) errors.report("Error making Real-ESRGAN models list", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
return None return None
def load_models(self, _): def load_models(self, _):
...@@ -135,5 +132,4 @@ def get_realesrgan_models(scaler): ...@@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
] ]
return models return models
except Exception: except Exception:
print("Error making Real-ESRGAN models list:", file=sys.stderr) errors.report("Error making Real-ESRGAN models list", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
import os
from pathlib import Path
from modules.paths_internal import script_path
def is_restartable() -> bool:
"""
Return True if the webui is restartable (i.e. a wrapper script is watching and can start it again)
"""
return bool(os.environ.get('SD_WEBUI_RESTART'))
def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
(Path(script_path) / "tmp" / "restart").touch()
stop_program()
def stop_program() -> None:
os._exit(0)
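The restart module pairs with the tmp/restart handling added to launch above: webui.sh/webui.bat export SD_WEBUI_RESTART and relaunch when tmp/restart appears. A usage sketch for settings/extension code (the function name is hypothetical):
from modules import restart
def apply_and_restart():
    if restart.is_restartable():
        restart.restart_program()  # touches tmp/restart, then exits the process
    else:
        print("webui was not started via webui.sh/webui.bat; please restart it manually")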
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
import pickle import pickle
import collections import collections
import sys
import traceback
import torch import torch
import numpy import numpy
...@@ -11,7 +9,10 @@ import _codecs ...@@ -11,7 +9,10 @@ import _codecs
import zipfile import zipfile
import re import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args): def encode(*args):
...@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): ...@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler) check_pt(filename, extra_handler)
except pickle.UnpicklingError: except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) errors.report(
print(traceback.format_exc(), file=sys.stderr) f"Error verifying pickled file from {filename}\n"
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) "-----> !!!! The file is most likely corrupted !!!! <-----\n"
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
exc_info=True,
)
return None return None
except Exception: except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) errors.report(
print(traceback.format_exc(), file=sys.stderr) f"Error verifying pickled file from {filename}\n"
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) f"The file may be malicious, so the program is not going to read it.\n"
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
exc_info=True,
)
return None return None
return unsafe_torch_load(filename, *args, **kwargs) return unsafe_torch_load(filename, *args, **kwargs)
...@@ -190,4 +194,3 @@ with safe.Extra(handler): ...@@ -190,4 +194,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load unsafe_torch_load = torch.load
torch.load = load torch.load = load
global_extra_handler = None global_extra_handler = None
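For context, load_with_extra's extra_handler lets a caller whitelist specific globals during unpickling. A minimal sketch, with an illustrative module/name pair:
import collections
def extra_handler(module, name):
    # return the object to let it through the unpickling whitelist,
    # or None to fall back to the built-in checks
    if module == 'collections' and name == 'OrderedDict':
        return collections.OrderedDict
    return None
# state_dict = load_with_extra('model.ckpt', extra_handler=extra_handler)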
import sys
import traceback
from collections import namedtuple
import inspect import inspect
import os
from collections import namedtuple
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
from modules import errors, timer
def report_exception(c, job): def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr) errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
class ImageSaveParams: class ImageSaveParams:
...@@ -111,6 +111,7 @@ callback_map = dict( ...@@ -111,6 +111,7 @@ callback_map = dict(
callbacks_before_ui=[], callbacks_before_ui=[],
callbacks_on_reload=[], callbacks_on_reload=[],
callbacks_list_optimizers=[], callbacks_list_optimizers=[],
callbacks_list_unets=[],
) )
...@@ -123,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): ...@@ -123,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']: for c in callback_map['callbacks_app_started']:
try: try:
c.callback(demo, app) c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script))
except Exception: except Exception:
report_exception(c, 'app_started_callback') report_exception(c, 'app_started_callback')
...@@ -271,16 +273,28 @@ def list_optimizers_callback(): ...@@ -271,16 +273,28 @@ def list_optimizers_callback():
return res return res
def list_unets_callback():
res = []
for c in callback_map['callbacks_list_unets']:
try:
c.callback(res)
except Exception:
report_exception(c, 'list_unets')
return res
def add_callback(callbacks, fun): def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks(): def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if stack else 'unknown file'
if filename == 'unknown file': if filename == 'unknown file':
return return
for callback_list in callback_map.values(): for callback_list in callback_map.values():
...@@ -430,3 +444,10 @@ def on_list_optimizers(callback): ...@@ -430,3 +444,10 @@ def on_list_optimizers(callback):
to it.""" to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback) add_callback(callback_map['callbacks_list_optimizers'], callback)
def on_list_unets(callback):
"""register a function to be called when UI is making a list of alternative options for unet.
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
add_callback(callback_map['callbacks_list_unets'], callback)
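A sketch of what an extension registering an alternative unet might look like, using the SdUnetOption/SdUnet classes from modules/sd_unet.py further below (the checkpoint name and the no-op forward are illustrative only):
from modules import script_callbacks, sd_unet
class ExampleUnet(sd_unet.SdUnet):
    def forward(self, x, timesteps, context, *args, **kwargs):
        return x  # placeholder; a real implementation predicts noise here
class ExampleUnetOption(sd_unet.SdUnetOption):
    model_name = "v1-5-pruned-emaonly"  # hypothetical checkpoint this unet matches
    label = "Example unet"              # shown in the UI dropdown
    def create_unet(self):
        return ExampleUnet()
script_callbacks.on_list_unets(lambda unets: unets.append(ExampleUnetOption()))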
import os import os
import sys
import traceback
import importlib.util import importlib.util
from modules import errors
def load_module(path): def load_module(path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
...@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): ...@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser) module.preload(parser)
except Exception: except Exception:
print(f"Error running preload() for {preload_script}", file=sys.stderr) errors.report(f"Error running preload() for {preload_script}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
...@@ -3,7 +3,7 @@ from torch.nn.functional import silu ...@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType from types import MethodType
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
...@@ -43,7 +43,7 @@ def list_optimizers(): ...@@ -43,7 +43,7 @@ def list_optimizers():
optimizers.extend(new_optimizers) optimizers.extend(new_optimizers)
def apply_optimizations(): def apply_optimizations(option=None):
global current_optimizer global current_optimizer
undo_optimizations() undo_optimizations()
...@@ -60,7 +60,7 @@ def apply_optimizations(): ...@@ -60,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo() current_optimizer.undo()
current_optimizer = None current_optimizer = None
selection = shared.opts.cross_attention_optimization selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0: if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else: else:
...@@ -74,12 +74,13 @@ def apply_optimizations(): ...@@ -74,12 +74,13 @@ def apply_optimizations():
matching_optimizer = optimizers[0] matching_optimizer = optimizers[0]
if matching_optimizer is not None: if matching_optimizer is not None:
print(f"Applying optimization: {matching_optimizer.name}... ", end='') print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply() matching_optimizer.apply()
print("done.") print("done.")
current_optimizer = matching_optimizer current_optimizer = matching_optimizer
return current_optimizer.name return current_optimizer.name
else: else:
print("Disabling attention optimization")
return '' return ''
...@@ -157,9 +158,9 @@ class StableDiffusionModelHijack: ...@@ -157,9 +158,9 @@ class StableDiffusionModelHijack:
def __init__(self): def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self): def apply_optimizations(self, option=None):
try: try:
self.optimization_method = apply_optimizations() self.optimization_method = apply_optimizations(option)
except Exception as e: except Exception as e:
errors.display(e, "applying cross attention optimization") errors.display(e, "applying cross attention optimization")
undo_optimizations() undo_optimizations()
...@@ -196,6 +197,11 @@ class StableDiffusionModelHijack: ...@@ -196,6 +197,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
...@@ -217,6 +223,8 @@ class StableDiffusionModelHijack: ...@@ -217,6 +223,8 @@ class StableDiffusionModelHijack:
self.layers = None self.layers = None
self.clip = None self.clip = None
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
return return
......
...@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): ...@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk.multipliers += [weight] * emb_len chunk.multipliers += [weight] * emb_len
position += embedding_length_in_tokens position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0: if chunk.tokens or not chunks:
next_chunk(is_last=True) next_chunk(is_last=True)
return chunks, token_count return chunks, token_count
......
...@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text ...@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if used_custom_terms:
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}") self.hijack.comments.append(f"Used embeddings: {embedding_names}")
......
from __future__ import annotations from __future__ import annotations
import math import math
import sys
import traceback
import psutil import psutil
import torch import torch
...@@ -48,7 +46,7 @@ class SdOptimizationXformers(SdOptimization): ...@@ -48,7 +46,7 @@ class SdOptimizationXformers(SdOptimization):
priority = 100 priority = 100
def is_available(self): def is_available(self):
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
def apply(self): def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
...@@ -140,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: ...@@ -140,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops import xformers.ops
shared.xformers_available = True shared.xformers_available = True
except Exception: except Exception:
print("Cannot import xformers", file=sys.stderr) errors.report("Cannot import xformers", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
def get_available_vram(): def get_available_vram():
...@@ -605,7 +602,7 @@ def sdp_attnblock_forward(self, x): ...@@ -605,7 +602,7 @@ def sdp_attnblock_forward(self, x):
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k, v = q.float(), k.float(), v.float()
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
......
...@@ -14,7 +14,7 @@ import ldm.modules.midas as midas ...@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd import tomesd
...@@ -95,8 +95,7 @@ except Exception: ...@@ -95,8 +95,7 @@ except Exception:
def setup_model(): def setup_model():
if not os.path.exists(model_path): os.makedirs(model_path, exist_ok=True)
os.makedirs(model_path)
enable_midas_autodownload() enable_midas_autodownload()
...@@ -164,6 +163,7 @@ def model_hash(filename): ...@@ -164,6 +163,7 @@ def model_hash(filename):
def select_checkpoint(): def select_checkpoint():
"""Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
...@@ -171,14 +171,14 @@ def select_checkpoint(): ...@@ -171,14 +171,14 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
if len(checkpoints_list) == 0: if len(checkpoints_list) == 0:
print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None: if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
print(f" - directory {model_path}", file=sys.stderr) error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None: if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr) error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
exit(1) raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values())) checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None: if model_checkpoint is not None:
...@@ -421,7 +421,7 @@ class SdModelData: ...@@ -421,7 +421,7 @@ class SdModelData:
try: try:
load_model() load_model()
except Exception as e: except Exception as e:
errors.display(e, "loading stable diffusion model") errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr) print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr) print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None self.sd_model = None
...@@ -506,6 +506,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): ...@@ -506,6 +506,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks") timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
timer.record("calculate empty prompt")
print(f"Model loaded in {timer.summary()}.") print(f"Model loaded in {timer.summary()}.")
return sd_model return sd_model
...@@ -525,6 +530,8 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -525,6 +530,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
sd_unet.apply_unet("None")
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
else: else:
......
...@@ -20,7 +20,7 @@ samplers_k_diffusion = [ ...@@ -20,7 +20,7 @@ samplers_k_diffusion = [
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True, 'discard_next_to_last_sigma': True}), ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
...@@ -29,7 +29,7 @@ samplers_k_diffusion = [ ...@@ -29,7 +29,7 @@ samplers_k_diffusion = [
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True, 'discard_next_to_last_sigma': True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
] ]
samplers_data_k_diffusion = [ samplers_data_k_diffusion = [
...@@ -44,6 +44,14 @@ sampler_extra_params = { ...@@ -44,6 +44,14 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
} }
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {
'Automatic': None,
'karras': k_diffusion.sampling.get_sigmas_karras,
'exponential': k_diffusion.sampling.get_sigmas_exponential,
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
}
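These entries map directly onto k-diffusion's sigma schedule builders. A sketch of the call the custom-scheduler branch further down ends up making (values illustrative; assumes the pinned k-diffusion revision keeps these signatures):
import k_diffusion.sampling
sigmas = k_diffusion.sampling.get_sigmas_karras(
    n=20,            # sampling steps
    sigma_min=0.1,
    sigma_max=10.0,
    rho=7.0,         # karras default; polyexponential defaults to 1.0
    device='cpu',
)
# returns a descending tensor of n + 1 noise levels ending in 0;
# get_sigmas_exponential takes no rho, which is why the rho handling
# below skips the exponential schedule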
class CFGDenoiser(torch.nn.Module): class CFGDenoiser(torch.nn.Module):
""" """
...@@ -61,6 +69,7 @@ class CFGDenoiser(torch.nn.Module): ...@@ -61,6 +69,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None self.init_latent = None
self.step = 0 self.step = 0
self.image_cfg_scale = None self.image_cfg_scale = None
self.padded_cond_uncond = False
def combine_denoised(self, x_out, conds_list, uncond, cond_scale): def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:] denoised_uncond = x_out[-uncond.shape[0]:]
...@@ -125,6 +134,18 @@ class CFGDenoiser(torch.nn.Module): ...@@ -125,6 +134,18 @@ class CFGDenoiser(torch.nn.Module):
x_in = x_in[:-batch_size] x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size] sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
self.padded_cond_uncond = True
elif num_repeats > 0:
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
self.padded_cond_uncond = True
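# worked example of the padding above (token counts illustrative): with a
# two-chunk prompt, tensor.shape[1] == 154 while uncond.shape[1] == 77, and the
# precomputed empty-prompt embedding has empty.shape[1] == 77, so
# num_repeats == (154 - 77) // 77 == 1 and one empty chunk is appended to
# uncond; the num_repeats < 0 branch mirrors this when the negative prompt
# is the longer one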
if tensor.shape[1] == uncond.shape[1] or skip_uncond: if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model: if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond]) cond_in = torch.cat([tensor, uncond, uncond])
...@@ -255,6 +276,13 @@ class KDiffusionSampler: ...@@ -255,6 +276,13 @@ class KDiffusionSampler:
try: try:
return func() return func()
except RecursionError:
print(
'Encountered RecursionError during sampling, returning last latent. '
'rho > 5 with the polyexponential scheduler may cause this error. '
'Try a smaller rho value instead.'
)
return self.last_latent
except sd_samplers_common.InterruptedException: except sd_samplers_common.InterruptedException:
return self.last_latent return self.last_latent
...@@ -294,6 +322,31 @@ class KDiffusionSampler: ...@@ -294,6 +322,31 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps) sigmas = p.sampler_noise_scheduler_override(steps)
elif opts.k_sched_type != "Automatic":
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
sigmas_kwargs = {
'sigma_min': sigma_min,
'sigma_max': sigma_max,
}
sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
p.extra_generation_params["Schedule type"] = opts.k_sched_type
if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
sigmas_kwargs['sigma_min'] = opts.sigma_min
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
sigmas_kwargs['sigma_max'] = opts.sigma_max
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
sigmas_kwargs['rho'] = opts.rho
p.extra_generation_params["Schedule rho"] = opts.rho
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
...@@ -355,6 +408,9 @@ class KDiffusionSampler: ...@@ -355,6 +408,9 @@ class KDiffusionSampler:
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
...@@ -388,5 +444,8 @@ class KDiffusionSampler: ...@@ -388,5 +444,8 @@ class KDiffusionSampler:
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) }, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples return samples
import torch.nn
import ldm.modules.diffusionmodules.openaimodel
from modules import script_callbacks, shared, devices
unet_options = []
current_unet_option = None
current_unet = None
def list_unets():
new_unets = script_callbacks.list_unets_callback()
unet_options.clear()
unet_options.extend(new_unets)
def get_unet_option(option=None):
option = option or shared.opts.sd_unet
if option == "None":
return None
if option == "Automatic":
name = shared.sd_model.sd_checkpoint_info.model_name
options = [x for x in unet_options if x.model_name == name]
option = options[0].label if options else "None"
return next(iter([x for x in unet_options if x.label == option]), None)
def apply_unet(option=None):
global current_unet_option
global current_unet
new_option = get_unet_option(option)
if new_option == current_unet_option:
return
if current_unet is not None:
print(f"Dectivating unet: {current_unet.option.label}")
current_unet.deactivate()
current_unet_option = new_option
if current_unet_option is None:
current_unet = None
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
shared.sd_model.model.diffusion_model.to(devices.device)
return
shared.sd_model.model.diffusion_model.to(devices.cpu)
devices.torch_gc()
current_unet = current_unet_option.create_unet()
current_unet.option = current_unet_option
print(f"Activating unet: {current_unet.option.label}")
current_unet.activate()
class SdUnetOption:
model_name = None
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
label = None
"""name of the unet in UI"""
def create_unet(self):
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
raise NotImplementedError()
class SdUnet(torch.nn.Module):
def forward(self, x, timesteps, context, *args, **kwargs):
raise NotImplementedError()
def activate(self):
pass
def deactivate(self):
pass
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
...@@ -29,3 +29,41 @@ def cross_attention_optimizations(): ...@@ -29,3 +29,41 @@ def cross_attention_optimizations():
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
def sd_unet_items():
import modules.sd_unet
return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
def refresh_unet_list():
import modules.sd_unet
modules.sd_unet.list_unets()
ui_reorder_categories_builtin_items = [
"inpaint",
"sampler",
"checkboxes",
"hires_fix",
"dimensions",
"cfg",
"seed",
"batch",
"override_settings",
]
def ui_reorder_categories():
from modules import scripts
yield from ui_reorder_categories_builtin_items
sections = {}
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
if isinstance(script.section, str):
sections[script.section] = 1
yield from sections
yield "scripts"
import csv import csv
import os import os
import os.path import os.path
import re
import typing import typing
import shutil import shutil
...@@ -28,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles): ...@@ -28,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles):
return prompt return prompt
re_spaces = re.compile(" +")
def extract_style_text_from_prompt(style_text, prompt):
stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
if "{prompt}" in stripped_style_text:
left, right = stripped_style_text.split("{prompt}", 1)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
return True, prompt
else:
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
if prompt.endswith(', '):
prompt = prompt[:-2]
return True, prompt
return False, prompt
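A worked example of the matching rules above (values hypothetical):
# with a {prompt} placeholder, the style must wrap the prompt:
#   extract_style_text_from_prompt("photo of {prompt}, 4k", "photo of a cat, 4k")
#   -> (True, "a cat")
# without one, the style must be a suffix; a trailing ", " separator is dropped:
#   extract_style_text_from_prompt("masterpiece", "a cat, masterpiece")
#   -> (True, "a cat")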
def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
if not match_positive:
return False, prompt, negative_prompt
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
if not match_negative:
return False, prompt, negative_prompt
return True, extracted_positive, extracted_negative
class StyleDatabase: class StyleDatabase:
def __init__(self, path: str): def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "") self.no_style = PromptStyle("None", "", "")
...@@ -67,10 +106,34 @@ class StyleDatabase: ...@@ -67,10 +106,34 @@ class StyleDatabase:
if os.path.exists(path): if os.path.exists(path):
shutil.copy(path, f"{path}.bak") shutil.copy(path, f"{path}.bak")
fd = os.open(path, os.O_RDWR|os.O_CREAT) fd = os.open(path, os.O_RDWR | os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader() writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items()) writer.writerows(style._asdict() for k, style in self.styles.items())
def extract_styles_from_prompt(self, prompt, negative_prompt):
extracted = []
applicable_styles = list(self.styles.values())
while True:
found_style = None
for style in applicable_styles:
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
if is_match:
found_style = style
prompt = new_prompt
negative_prompt = new_neg_prompt
break
if not found_style:
break
applicable_styles.remove(found_style)
extracted.append(found_style.name)
return list(reversed(extracted)), prompt, negative_prompt
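A hedged usage sketch (assumes a styles.csv defining the styles matched above):
# db = StyleDatabase("styles.csv")
# names, prompt, negative = db.extract_styles_from_prompt(
#     "photo of a cat, 4k, masterpiece", "blurry, lowres")
# styles are peeled off one at a time, most recently applied first, then the
# list is reversed so `names` reads in the order the styles were applied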
import json
import os
import sys
import traceback
import platform
import hashlib
import pkg_resources
import psutil
import re
import launch
from modules import paths_internal, timer
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = {
"GIT",
"INDEX_URL",
"WEBUI_LAUNCH_LIVE_OUTPUT",
"GRADIO_ANALYTICS_ENABLED",
"PYTHONPATH",
"TORCH_INDEX_URL",
"TORCH_COMMAND",
"REQS_FILE",
"XFORMERS_PACKAGE",
"GFPGAN_PACKAGE",
"CLIP_PACKAGE",
"OPENCLIP_PACKAGE",
"STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO",
"CODEFORMER_REPO",
"BLIP_REPO",
"STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH",
"CODEFORMER_COMMIT_HASH",
"BLIP_COMMIT_HASH",
"COMMANDLINE_ARGS",
"IGNORE_CMD_ARGS_ERRORS",
}
def pretty_bytes(num, suffix="B"):
for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
if abs(num) < 1024 or unit == 'Y':
return f"{num:.0f}{unit}{suffix}"
num /= 1024
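pretty_bytes divides by 1024 until the value fits the current unit, for example:
# pretty_bytes(512)           -> '512B'
# pretty_bytes(16 * 1024**3)  -> '16GB'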
def get():
res = get_dict()
text = json.dumps(res, ensure_ascii=False, indent=4)
h = hashlib.sha256(text.encode("utf8"))
text = text.replace(checksum_token, h.hexdigest())
return text
re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')
def check(x):
m = re.search(re_checksum, x)
if not m:
return False
replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)
h = hashlib.sha256(replaced.encode("utf8"))
return h.hexdigest() == m.group(1)
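get embeds checksum_token into the report, hashes the whole text, then substitutes the digest; check swaps the token back in and recomputes. A round-trip sketch:
# report_text = get()
# check(report_text)                                  # True for an untouched report
# check(report_text.replace('"Python"', '"python"'))  # False once anything is edited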
def get_dict():
ram = psutil.virtual_memory()
res = {
"Platform": platform.platform(),
"Python": platform.python_version(),
"Version": launch.git_tag(),
"Commit": launch.commit_hash(),
"Script path": paths_internal.script_path,
"Data path": paths_internal.data_path,
"Extensions dir": paths_internal.extensions_dir,
"Checksum": checksum_token,
"Commandline": sys.argv,
"Torch env info": get_torch_sysinfo(),
"Exceptions": get_exceptions(),
"CPU": {
"model": platform.processor(),
"count logical": psutil.cpu_count(logical=True),
"count physical": psutil.cpu_count(logical=False),
},
"RAM": {
x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
},
"Extensions": get_extensions(enabled=True),
"Inactive extensions": get_extensions(enabled=False),
"Environment": get_environment(),
"Config": get_config(),
"Startup": timer.startup_record,
"Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
}
return res
def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
def get_exceptions():
try:
from modules import errors
return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
except Exception as e:
return str(e)
def get_environment():
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
re_newline = re.compile(r"\r*\n")
def get_torch_sysinfo():
try:
import torch.utils.collect_env
info = torch.utils.collect_env.get_env_info()._asdict()
return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
except Exception as e:
return str(e)
def get_extensions(*, enabled):
try:
from modules import extensions
def to_json(x: extensions.Extension):
return {
"name": x.name,
"path": x.path,
"version": x.version,
"branch": x.branch,
"remote": x.remote,
}
return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
except Exception as e:
return str(e)
def get_config():
try:
from modules import shared
return shared.opts.data
except Exception as e:
return str(e)
...@@ -77,27 +77,27 @@ def focal_point(im, settings): ...@@ -77,27 +77,27 @@ def focal_point(im, settings):
pois = [] pois = []
weight_pref_total = 0 weight_pref_total = 0
if len(corner_points) > 0: if corner_points:
weight_pref_total += settings.corner_points_weight weight_pref_total += settings.corner_points_weight
if len(entropy_points) > 0: if entropy_points:
weight_pref_total += settings.entropy_points_weight weight_pref_total += settings.entropy_points_weight
if len(face_points) > 0: if face_points:
weight_pref_total += settings.face_points_weight weight_pref_total += settings.face_points_weight
corner_centroid = None corner_centroid = None
if len(corner_points) > 0: if corner_points:
corner_centroid = centroid(corner_points) corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid) pois.append(corner_centroid)
entropy_centroid = None entropy_centroid = None
if len(entropy_points) > 0: if entropy_points:
entropy_centroid = centroid(entropy_points) entropy_centroid = centroid(entropy_points)
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
pois.append(entropy_centroid) pois.append(entropy_centroid)
face_centroid = None face_centroid = None
if len(face_points) > 0: if face_points:
face_centroid = centroid(face_points) face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid) pois.append(face_centroid)
...@@ -187,7 +187,7 @@ def image_face_points(im, settings): ...@@ -187,7 +187,7 @@ def image_face_points(im, settings):
except Exception: except Exception:
continue continue
if len(faces) > 0: if faces:
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return [] return []
...@@ -298,8 +298,7 @@ def download_and_cache_models(dirname): ...@@ -298,8 +298,7 @@ def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx' model_file_name = 'face_detection_yunet.onnx'
if not os.path.exists(dirname): os.makedirs(dirname, exist_ok=True)
os.makedirs(dirname)
cache_file = os.path.join(dirname, model_file_name) cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file): if not os.path.exists(cache_file):
......
...@@ -32,7 +32,7 @@ class DatasetEntry: ...@@ -32,7 +32,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
......
...@@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti ...@@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption += shared.interrogator.generate_caption(image) caption += shared.interrogator.generate_caption(image)
if params.process_caption_deepbooru: if params.process_caption_deepbooru:
if len(caption) > 0: if caption:
caption += ", " caption += ", "
caption += deepbooru.model.tag_multi(image) caption += deepbooru.model.tag_multi(image)
...@@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti ...@@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption = caption.strip() caption = caption.strip()
if len(caption) > 0: if caption:
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption) file.write(caption)
......
import os import os
import sys
import traceback
from collections import namedtuple from collections import namedtuple
import torch import torch
...@@ -14,7 +12,7 @@ import numpy as np ...@@ -14,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
...@@ -120,16 +118,29 @@ class EmbeddingDatabase: ...@@ -120,16 +118,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear() self.embedding_dirs.clear()
def register_embedding(self, embedding, model): def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding return self.register_embedding_by_name(embedding, model, embedding.name)
ids = model.cond_stage_model.tokenize([embedding.name])[0]
def register_embedding_by_name(self, embedding, model, name):
ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0] first_id = ids[0]
if first_id not in self.ids_lookup: if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = [] self.ids_lookup[first_id] = []
if name in self.word_embeddings:
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) # remove old one from the lookup list
lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
else:
lookup = self.ids_lookup[first_id]
if embedding is not None:
lookup += [(ids, embedding)]
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
if embedding is None:
# unregister embedding with specified name
if name in self.word_embeddings:
del self.word_embeddings[name]
if len(self.ids_lookup[first_id]) == 0:
del self.ids_lookup[first_id]
return None
self.word_embeddings[name] = embedding
return embedding return embedding
def get_expected_shape(self): def get_expected_shape(self):
...@@ -207,8 +218,7 @@ class EmbeddingDatabase: ...@@ -207,8 +218,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn) self.load_from_file(fullfn, fn)
except Exception: except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr) errors.report(f"Error loading embedding {fn}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
continue continue
def load_textual_inversion_embeddings(self, force_reload=False): def load_textual_inversion_embeddings(self, force_reload=False):
...@@ -241,7 +251,7 @@ class EmbeddingDatabase: ...@@ -241,7 +251,7 @@ class EmbeddingDatabase:
if self.previously_displayed_embeddings != displayed_embeddings: if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0: if self.skipped_embeddings:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
...@@ -632,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -632,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception: except Exception:
print(traceback.format_exc(), file=sys.stderr) errors.report("Error training embedding", exc_info=True)
pass
finally: finally:
pbar.leave = False pbar.leave = False
pbar.close() pbar.close()
......
import time import time
class TimerSubcategory:
def __init__(self, timer, category):
self.timer = timer
self.category = category
self.start = None
self.original_base_category = timer.base_category
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategory = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategory)
self.timer.record(self.category)
class Timer: class Timer:
def __init__(self): def __init__(self):
self.start = time.time() self.start = time.time()
self.records = {} self.records = {}
self.total = 0 self.total = 0
self.base_category = ''
def elapsed(self): def elapsed(self):
end = time.time() end = time.time()
...@@ -13,18 +32,29 @@ class Timer: ...@@ -13,18 +32,29 @@ class Timer:
self.start = end self.start = end
return res return res
def record(self, category, extra_time=0): def add_time_to_record(self, category, amount):
e = self.elapsed()
if category not in self.records: if category not in self.records:
self.records[category] = 0 self.records[category] = 0
self.records[category] += e + extra_time self.records[category] += amount
def record(self, category, extra_time=0):
e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time self.total += e + extra_time
def subcategory(self, name):
self.elapsed()
subcat = TimerSubcategory(self, name)
return subcat
def summary(self): def summary(self):
res = f"{self.total:.1f}s" res = f"{self.total:.1f}s"
additions = [x for x in self.records.items() if x[1] >= 0.1] additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
if not additions: if not additions:
return res return res
...@@ -34,5 +64,13 @@ class Timer: ...@@ -34,5 +64,13 @@ class Timer:
return res return res
def dump(self):
return {'total': self.total, 'records': self.records}
def reset(self): def reset(self):
self.__init__() self.__init__()
startup_timer = Timer()
startup_record = None
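A usage sketch for the new subcategory support (sleep stands in for real work):
import time
from modules.timer import Timer
timer = Timer()
with timer.subcategory("load model"):
    time.sleep(0.2)              # stand-in for real work
    timer.record("weights")      # recorded as "load model/weights"
timer.record("move to device")
print(timer.summary())           # e.g. "0.2s (load model: 0.2s)"
# keys containing '/' stay in records but are omitted from summary();
# their time is already accumulated into the parent category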
...@@ -10,8 +10,11 @@ import subprocess as sp ...@@ -10,8 +10,11 @@ import subprocess as sp
from modules import call_queue, shared from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
import modules.images import modules.images
from modules.ui_components import ToolButton
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
def update_generation_info(generation_info, html_info, img_index): def update_generation_info(generation_info, html_info, img_index):
...@@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index):
save_to_dirs = shared.opts.use_save_to_dirs_for_ui save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format extension: str = shared.opts.samples_format
start_index = 0 start_index = 0
only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
only_one = True
images = [images[index]] images = [images[index]]
start_index = index start_index = index
...@@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index):
is_grid = image_index < p.index_of_first_image is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image) i = 0 if is_grid else (image_index - p.index_of_first_image)
p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path) filename = os.path.relpath(fullfn, path)
...@@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index):
# Make Zip # Make Zip
if do_make_zip: if do_make_zip:
zip_filepath = os.path.join(path, "images.zip") zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
from zipfile import ZipFile from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file: with ZipFile(zip_filepath, "w") as zip_file:
...@@ -211,3 +219,23 @@ Requested path was: {f} ...@@ -211,3 +219,23 @@ Requested path was: {f}
)) ))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button
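A hedged wiring example for `create_refresh_button`, assuming the helper above and its imports are in scope; the component, callbacks, and `elem_id` are illustrative, not taken from the webui:

```python
import gradio as gr

def list_samplers():
    return ["Euler a", "DDIM"]  # placeholder data source

with gr.Blocks() as demo:
    sampler = gr.Dropdown(label="Sampler", choices=list_samplers())
    create_refresh_button(
        sampler,
        refresh_method=lambda: None,                          # side-effecting reload hook
        refreshed_args=lambda: {"choices": list_samplers()},  # fresh kwargs applied to the component
        elem_id="example_refresh_sampler",
    )
```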
...@@ -4,6 +4,7 @@ from pathlib import Path ...@@ -4,6 +4,7 @@ from pathlib import Path
from modules import shared from modules import shared
from modules.images import read_info_from_image, save_image_with_geninfo from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr import gradio as gr
import json import json
import html import html
...@@ -185,6 +186,8 @@ class ExtraNetworksPage: ...@@ -185,6 +186,8 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never": if search_only and shared.opts.extra_networks_hidden_models == "Never":
return "" return ""
sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
args = { args = {
"background_image": background_image, "background_image": background_image,
"style": f"'display: none; {height}{width}'", "style": f"'display: none; {height}{width}'",
...@@ -198,10 +201,23 @@ class ExtraNetworksPage: ...@@ -198,10 +201,23 @@ class ExtraNetworksPage:
"search_term": item.get("search_term", ""), "search_term": item.get("search_term", ""),
"metadata_button": metadata_button, "metadata_button": metadata_button,
"search_only": " search_only" if search_only else "", "search_only": " search_only" if search_only else "",
"sort_keys": sort_keys,
} }
return self.card_page.format(**args) return self.card_page.format(**args)
def get_sort_keys(self, path):
"""
List of default keys used for sorting in the UI.
"""
pth = Path(path)
stat = pth.stat()
return {
"date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0),
"name": pth.name.lower(),
}
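A self-contained sketch of what `get_sort_keys` computes, using this script's own file as stand-in input; note that on Linux `st_ctime` is inode metadata-change time rather than true creation time:

```python
from pathlib import Path

pth = Path(__file__)  # any existing file works for the demonstration
stat = pth.stat()
sort_keys = {
    "date_created": int(stat.st_ctime or 0),   # metadata-change time on Linux, creation time on Windows
    "date_modified": int(stat.st_mtime or 0),
    "name": pth.name.lower(),
}
print(sort_keys)
```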
def find_preview(self, path): def find_preview(self, path):
""" """
Find a preview PNG for a given path (without extension) and call link_preview on it. Find a preview PNG for a given path (without extension) and call link_preview on it.
...@@ -296,6 +312,8 @@ def create_ui(container, button, tabname): ...@@ -296,6 +312,8 @@ def create_ui(container, button, tabname):
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
......
...@@ -14,7 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): ...@@ -14,7 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def list_items(self): def list_items(self):
checkpoint: sd_models.CheckpointInfo checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items(): for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
path, ext = os.path.splitext(checkpoint.filename) path, ext = os.path.splitext(checkpoint.filename)
yield { yield {
"name": checkpoint.name_for_extra, "name": checkpoint.name_for_extra,
...@@ -24,6 +24,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): ...@@ -24,6 +24,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"local_preview": f"{path}.{shared.opts.samples_format}", "local_preview": f"{path}.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
} }
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
......
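The `'default': index` entry means "Default Sort" simply reproduces the insertion order of `sd_models.checkpoints_list`. Below is a Python illustration of the sort the new dropdown triggers; the real implementation runs in JavaScript against the `data-sort-*` attributes rendered on each card, and the item values here are made up:

```python
# items shaped like those yielded by list_items() above
items = [
    {"name": "v1-5", "sort_keys": {"default": 0, "date_modified": 1685003600, "name": "v1-5"}},
    {"name": "anything-v3", "sort_keys": {"default": 1, "date_modified": 1671000000, "name": "anything-v3"}},
]

sort_by = "date_modified"  # one of: default, date_created, date_modified, name
items.sort(key=lambda item: item["sort_keys"][sort_by])
print([item["name"] for item in items])  # ['anything-v3', 'v1-5']
```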