Commit ec9bbda3 authored by AUTOMATIC1111, committed by GitHub

Merge branch 'dev' into img2img-batch-png-info

parents c4c63dd5 18256c5f
...@@ -50,13 +50,14 @@ module.exports = { ...@@ -50,13 +50,14 @@ module.exports = {
globals: { globals: {
//script.js //script.js
gradioApp: "readonly", gradioApp: "readonly",
executeCallbacks: "readonly",
onAfterUiUpdate: "readonly",
onOptionsChanged: "readonly",
onUiLoaded: "readonly", onUiLoaded: "readonly",
onUiUpdate: "readonly", onUiUpdate: "readonly",
onOptionsChanged: "readonly",
uiCurrentTab: "writable", uiCurrentTab: "writable",
uiElementIsVisible: "readonly",
uiElementInSight: "readonly", uiElementInSight: "readonly",
executeCallbacks: "readonly", uiElementIsVisible: "readonly",
//ui.js //ui.js
opts: "writable", opts: "writable",
all_gallery_buttons: "readonly", all_gallery_buttons: "readonly",
...@@ -84,5 +85,7 @@ module.exports = { ...@@ -84,5 +85,7 @@ module.exports = {
// imageviewer.js // imageviewer.js
modalPrevImage: "readonly", modalPrevImage: "readonly",
modalNextImage: "readonly", modalNextImage: "readonly",
// token-counters.js
setupTokenCounters: "readonly",
} }
}; };
...@@ -43,8 +43,8 @@ body: ...@@ -43,8 +43,8 @@ body:
- type: input - type: input
id: commit id: commit
attributes: attributes:
label: Commit where the problem happens label: Version or Commit where the problem happens
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
validations: validations:
required: true required: true
- type: dropdown - type: dropdown
...@@ -80,6 +80,23 @@ body: ...@@ -80,6 +80,23 @@ body:
- AMD GPUs (RX 5000 below) - AMD GPUs (RX 5000 below)
- CPU - CPU
- Other GPUs - Other GPUs
- type: dropdown
id: cross_attention_opt
attributes:
label: Cross attention optimization
description: Which cross attention optimization are you using? (Settings -> Optimizations -> Cross attention optimization) description: Which cross attention optimization are you using? (Settings -> Optimizations -> Cross attention optimization)
multiple: false
options:
- Automatic
- xformers
- sdp-no-mem
- sdp
- Doggettx
- V1
- InvokeAI
- "None "
validations:
required: true
- type: dropdown - type: dropdown
id: browsers id: browsers
attributes: attributes:
......
...@@ -18,7 +18,7 @@ jobs: ...@@ -18,7 +18,7 @@ jobs:
# not to have GHA download an (at the time of writing) 4 GB cache # not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies. # of PyTorch and other dependencies.
- name: Install Ruff - name: Install Ruff
run: pip install ruff==0.0.265 run: pip install ruff==0.0.272
- name: Run Ruff - name: Run Ruff
run: ruff . run: ruff .
lint-js: lint-js:
......
...@@ -42,7 +42,7 @@ jobs: ...@@ -42,7 +42,7 @@ jobs:
--no-half --no-half
--disable-opt-split-attention --disable-opt-split-attention
--use-cpu all --use-cpu all
--add-stop-route --api-server-stop
2>&1 | tee output.txt & 2>&1 | tee output.txt &
- name: Run tests - name: Run tests
run: | run: |
...@@ -50,7 +50,7 @@ jobs: ...@@ -50,7 +50,7 @@ jobs:
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server - name: Kill test server
if: always() if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10 run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
- name: Show coverage - name: Show coverage
run: | run: |
python -m coverage combine .coverage* python -m coverage combine .coverage*
......
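The workflow change above tracks a rename in webui itself: the old --add-stop-route flag and /_stop route became --api-server-stop and /sdapi/v1/server-stop. A minimal sketch of calling the renamed endpoint from Python, assuming a locally running webui launched with --api-server-stop as in the CI step (requests is a third-party package; plain urllib works too):

import requests

# Ask the local webui to shut down; the CI step above does the same with
# curl and then sleeps 10 seconds to give the server time to exit.
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/server-stop")
print(resp.status_code)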
## 1.4.0
### Features:
* zoom controls for inpainting
* run basic torch calculation at startup in parallel to reduce the performance impact of first generation
* option to pad prompt/neg prompt to be same length
* remove taming_transformers dependency
* custom k-diffusion scheduler settings
* add an option to show selected settings in main txt2img/img2img UI
* sysinfo tab in settings
* infer styles from prompts when pasting params into the UI
* an option to control the behavior of the above
### Minor:
* bump Gradio to 3.32.0
* bump xformers to 0.0.20
* Add option to disable token counters
* tooltip fixes & optimizations
* make it possible to configure filename for the zip download
* `[vae_filename]` pattern for filenames
* Revert discarding penultimate sigma for DPM-Solver++(2M) SDE
* change UI reorder setting to multiselect
* read version info from CHANGELOG.md if git version info is not available
* link footer API to Wiki when API is not active
* persistent conds cache (opt-in optimization)
### Extensions:
* After installing extensions, webui properly restarts the process rather than reloading the UI
* Added VAE listing to web API via `/sdapi/v1/sd-vae` (a usage sketch follows this CHANGELOG diff)
* custom unet support
* Add onAfterUiUpdate callback
* refactor EmbeddingDatabase.register_embedding() to allow unregistering
* add before_process callback for scripts
* add ability for alwayson scripts to specify section and let user reorder those sections
### Bug Fixes:
* Fix dragging text to prompt
* fix incorrect quoting for infotext values with colon in them
* fix "hires. fix" prompt sharing same labels with txt2img_prompt
* Fix s_min_uncond default type int
* Fix for #10643 (Inpainting mask sometimes not working)
* fix bad styling for thumbs view in extra networks #10639
* fix for empty list of optimizations #10605
* small fixes to prepare_tcmalloc for Debian/Ubuntu compatibility
* fix --ui-debug-mode exit
* patch GitPython to not use leaky persistent processes
* fix duplicate Cross attention optimization after UI reload
* torch.cuda.is_available() check for SdOptimizationXformers
* fix hires fix using wrong conds in second pass if using Loras.
* handle exception when parsing generation parameters from png info
* fix upcast attention dtype error
* forcing Torch Version to 1.13.1 for RX 5000 series GPUs
* split mask blur into X and Y components, patch Outpainting MK2 accordingly
* don't die when a LoRA is a broken symlink
* allow activation of Generate Forever during generation
## 1.3.2 ## 1.3.2
### Bug Fixes: ### Bug Fixes:
......
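As promised above, a sketch for the VAE listing route added in 1.4.0. The route name comes straight from the changelog; the shape of each returned entry is an assumption here, so the sketch just prints whatever comes back:

import requests

# List the VAE files a locally running webui (with the API enabled) knows about.
for vae in requests.get("http://127.0.0.1:7860/sdapi/v1/sd-vae").json():
    print(vae)  # each entry describes one discovered VAE file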
import os import os
import sys
import traceback
from basicsr.utils.download_util import load_file_from_url
from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared, script_callbacks from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401 import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401 import sd_hijack_ddpm_v1 # noqa: F401
...@@ -45,22 +42,17 @@ class UpscalerLDSR(Upscaler): ...@@ -45,22 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path): if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path model = local_safetensors_path
else: else:
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
try:
return LDSR(model, yaml)
except Exception:
print("Error importing LDSR:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
return LDSR(model, yaml)
def do_upscale(self, img, path): def do_upscale(self, img, path):
ldsr = self.load_model(path)
if ldsr is None:
print("NO LDSR!")
try:
ldsr = self.load_model(path)
except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)
......
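The LDSR change above shows the error-handling convention this merge moves toward: load_model() raises instead of returning None, and the caller reports the exception and degrades gracefully. A self-contained sketch of that pattern; report() is a stand-in approximating the webui's modules.errors.report helper, of which only the (message, exc_info=True) usage seen in the diff is assumed:

import sys
import traceback

def report(message, *, exc_info=False):
    # Stand-in for modules.errors.report as used in the diff above.
    print(message, file=sys.stderr)
    if exc_info:
        print(traceback.format_exc(), file=sys.stderr)

def do_upscale_sketch(img, load_model, path):
    try:
        model = load_model(path)
    except Exception:
        report(f"Failed loading model {path}", exc_info=True)
        return img  # hand the input back unscaled instead of crashing
    return model(img)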
...@@ -10,7 +10,7 @@ from contextlib import contextmanager ...@@ -10,7 +10,7 @@ from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
...@@ -91,8 +91,9 @@ class VQModel(pl.LightningModule): ...@@ -91,8 +91,9 @@ class VQModel(pl.LightningModule):
del sd[k] del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0: if missing:
print(f"Missing Keys: {missing}") print(f"Missing Keys: {missing}")
if unexpected:
print(f"Unexpected Keys: {unexpected}") print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs): def on_train_batch_end(self, *args, **kwargs):
......
...@@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule): ...@@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False) sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0: if missing:
print(f"Missing Keys: {missing}") print(f"Missing Keys: {missing}")
if len(unexpected) > 0: if unexpected:
print(f"Unexpected Keys: {unexpected}") print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t): def q_mean_variance(self, x_start, t):
......
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
# where the license is as follows:
#
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits is False, "Only for interface compatible with Gumbel"
assert return_logits is False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
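For reference, the codebook loss computed in forward() above, with sg[·] denoting stop-gradient (.detach() in the code). The NOTE in the class docstring refers to beta landing on the codebook term instead of the commitment term in the default legacy path:

\mathcal{L}_{\mathrm{legacy}} = \left\lVert \mathrm{sg}[z_q] - z \right\rVert_2^2 + \beta \left\lVert z_q - \mathrm{sg}[z] \right\rVert_2^2,
\qquad
\mathcal{L}_{\mathrm{legacy=False}} = \beta \left\lVert \mathrm{sg}[z_q] - z \right\rVert_2^2 + \left\lVert z_q - \mathrm{sg}[z] \right\rVert_2^2.

The nearest-neighbour search implements the expansion d(z, e_j) = \lVert z \rVert^2 + \lVert e_j \rVert^2 - 2\langle z, e_j \rangle (hence the doubled, subtracted einsum term), and z_q \leftarrow z + \mathrm{sg}[z_q - z] is the straight-through estimator that lets gradients bypass the argmin.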
...@@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): ...@@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list): def activate(self, p, params_list):
additional = shared.opts.sd_lora additional = shared.opts.sd_lora
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = [] names = []
multipliers = [] multipliers = []
for params in params_list: for params in params_list:
assert len(params.items) > 0 assert params.items
names.append(params.items[0]) names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
......
...@@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk): ...@@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk):
else: else:
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0: if keys_failed_to_match:
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora return lora
...@@ -267,7 +267,7 @@ def load_loras(names, multipliers=None): ...@@ -267,7 +267,7 @@ def load_loras(names, multipliers=None):
lora.multiplier = multipliers[i] if multipliers else 1.0 lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora) loaded_loras.append(lora)
if len(failed_to_load_loras) > 0: if failed_to_load_loras:
sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
...@@ -448,7 +448,11 @@ def list_available_loras(): ...@@ -448,7 +448,11 @@ def list_available_loras():
continue continue
name = os.path.splitext(os.path.basename(filename))[0] name = os.path.splitext(os.path.basename(filename))[0]
entry = LoraOnDisk(name, filename) try:
entry = LoraOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
continue
available_loras[name] = entry available_loras[name] = entry
......
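The list_available_loras() change above is the changelog's "don't die when a LoRA is a broken symlink" fix: constructing LoraOnDisk touches the file, so a dangling symlink raises an OSError that previously crashed the whole listing. A self-contained sketch of the same scan-and-skip pattern (directory layout and names here are illustrative, not the webui's):

import os
import sys

def scan_model_dir(directory):
    found = {}
    for filename in os.listdir(directory):
        path = os.path.join(directory, filename)
        name = os.path.splitext(filename)[0]
        try:
            # os.stat raises OSError (FileNotFoundError, PermissionError, ...)
            # on broken symlinks and unreadable files alike
            found[name] = os.stat(path).st_size
        except OSError as e:
            print(f"Failed to load {name} from {path}: {e}", file=sys.stderr)
            continue
    return found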
...@@ -13,7 +13,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): ...@@ -13,7 +13,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
lora.list_available_loras() lora.list_available_loras()
def list_items(self): def list_items(self):
for name, lora_on_disk in lora.available_loras.items(): for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
path, ext = os.path.splitext(lora_on_disk.filename) path, ext = os.path.splitext(lora_on_disk.filename)
alias = lora_on_disk.get_alias() alias = lora_on_disk.get_alias()
...@@ -27,6 +27,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): ...@@ -27,6 +27,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"), "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}", "local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
} }
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
......
import os.path
import sys import sys
import traceback
import PIL.Image import PIL.Image
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url
from modules.shared import opts from modules.shared import opts
...@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = [] scalers = []
add_model2 = True add_model2 = True
for file in model_paths: for file in model_paths:
if "http" in file: if file.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
...@@ -38,8 +36,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -38,8 +36,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data) scalers.append(scaler_data)
except Exception: except Exception:
print(f"Error loading ScuNET model: {file}", file=sys.stderr) errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
if add_model2: if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2) scalers.append(scaler_data2)
...@@ -90,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -90,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache() torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img return img
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
...@@ -120,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -120,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str): def load_model(self, path: str):
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
if "http" in path: if path.startswith("http"):
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
# TODO: this doesn't use `path` at all?
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for _, v in model.named_parameters(): for _, v in model.named_parameters():
......
import os import sys
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir') device_swinir = devices.get_device_for('swinir')
...@@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir') ...@@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "SwinIR" self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ self.model_url = SWINIR_MODEL_URL
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_name = "SwinIR 4x" self.model_name = "SwinIR 4x"
self.user_path = dirname self.user_path = dirname
super().__init__() super().__init__()
scalers = [] scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"]) model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files: for model in model_files:
if "http" in model: if model.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(model) name = modelloader.friendly_name(model)
...@@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler): ...@@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img, model_file): def do_upscale(self, img, model_file):
model = self.load_model(model_file)
if model is None:
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img return img
model = model.to(device_swinir, dtype=devices.dtype) model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model) img = upscale(img, model)
...@@ -49,30 +49,31 @@ class UpscalerSwinIR(Upscaler): ...@@ -49,30 +49,31 @@ class UpscalerSwinIR(Upscaler):
return img return img
def load_model(self, path, scale=4): def load_model(self, path, scale=4):
if "http" in path: if path.startswith("http"):
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
filename = modelloader.load_file_from_url(
url=path,
model_dir=self.model_download_path,
file_name=f"{self.model_name.replace(' ', '_')}.pth",
)
else: else:
filename = path filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"): if filename.endswith(".v2.pth"):
model = net2( model = Swin2SR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
window_size=8, window_size=8,
img_range=1.0, img_range=1.0,
depths=[6, 6, 6, 6, 6, 6], depths=[6, 6, 6, 6, 6, 6],
embed_dim=180, embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2, mlp_ratio=2,
upsampler="nearest+conv", upsampler="nearest+conv",
resi_connection="1conv", resi_connection="1conv",
) )
params = None params = None
else: else:
model = net( model = SwinIR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
......
import gradio as gr
from modules import shared
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
.canvas-tooltip-info {
position: absolute;
top: 10px;
left: 10px;
cursor: help;
background-color: rgba(0, 0, 0, 0.3);
width: 20px;
height: 20px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
flex-direction: column;
z-index: 100;
}
.canvas-tooltip-info::after {
content: '';
display: block;
width: 2px;
height: 7px;
background-color: white;
margin-top: 2px;
}
.canvas-tooltip-info::before {
content: '';
display: block;
width: 2px;
height: 2px;
background-color: white;
}
.canvas-tooltip-content {
display: none;
background-color: #f9f9f9;
color: #333;
border: 1px solid #ddd;
padding: 15px;
position: absolute;
top: 40px;
left: 10px;
width: 250px;
font-size: 16px;
opacity: 0;
border-radius: 8px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
z-index: 100;
}
.canvas-tooltip:hover .canvas-tooltip-content {
display: block;
animation: fadeIn 0.5s;
opacity: 1;
}
@keyframes fadeIn {
from {opacity: 0;}
to {opacity: 1;}
}
import gradio as gr
from modules import scripts, shared, ui_components, ui_settings
from modules.ui_components import FormColumn
class ExtraOptionsSection(scripts.Script):
section = "extra_options"
def __init__(self):
self.comps = None
self.setting_names = None
def title(self):
return "Extra options"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.comps = []
self.setting_names = []
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
self.comps.append(comp)
self.setting_names.append(setting_name)
def get_settings_values():
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
return self.comps
def before_process(self, p, *args):
for name, value in zip(self.setting_names, args):
if name not in p.override_settings:
p.override_settings[name] = value
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
}))
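before_process() above is what makes the mirrored controls take effect: the webui applies p.override_settings per generation, and the "if name not in p.override_settings" guard means values already present there (for example, restored from pasted infotext) win over the extra-options widgets. A self-contained sketch of that precedence rule; the class is a stand-in, not the webui's processing object:

class FakeProcessing:
    def __init__(self):
        self.override_settings = {}

def apply_extra_options(p, setting_names, values):
    # mirrors ExtraOptionsSection.before_process above
    for name, value in zip(setting_names, values):
        if name not in p.override_settings:  # existing overrides take precedence
            p.override_settings[name] = value

p = FakeProcessing()
p.override_settings["CLIP_stop_at_last_layers"] = 1  # e.g. restored from infotext
apply_extra_options(p, ["CLIP_stop_at_last_layers", "sd_vae"], [2, "anime.vae.pt"])
print(p.override_settings)  # the pre-existing value 1 is kept, sd_vae is added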
<div class='card' style={style} onclick={card_clicked}> <div class='card' style={style} onclick={card_clicked} {sort_keys}>
{background_image} {background_image}
{metadata_button} {metadata_button}
<div class='actions'> <div class='actions'>
......
<div> <div>
<a href="/docs">API</a> <a href="{api_docs}">API</a>
 •   • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a> <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 •   • 
<a href="https://gradio.app">Gradio</a> <a href="https://gradio.app">Gradio</a>
 •   • 
<a href="#" onclick="showProfile('./internal/profile-startup'); return false;">Startup profile</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a> <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div> </div>
<br /> <br />
......
...@@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) { ...@@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) {
} }
onUiUpdate(function() { onAfterUiUpdate(function() {
var arPreviewRect = gradioApp().querySelector('#imageARPreview'); var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if (arPreviewRect) { if (arPreviewRect) {
arPreviewRect.style.display = 'none'; arPreviewRect.style.display = 'none';
......
...@@ -148,12 +148,18 @@ var addContextMenuEventListener = initResponse[2]; ...@@ -148,12 +148,18 @@ var addContextMenuEventListener = initResponse[2];
500); 500);
}; };
appendContextMenuOption('#txt2img_generate', 'Generate forever', function() {
generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
});
appendContextMenuOption('#img2img_generate', 'Generate forever', function() {
generateOnRepeat('#img2img_generate', '#img2img_interrupt');
});
let generateOnRepeat_txt2img = function() {
generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
};
let generateOnRepeat_img2img = function() {
generateOnRepeat('#img2img_generate', '#img2img_interrupt');
};
appendContextMenuOption('#txt2img_generate', 'Generate forever', generateOnRepeat_txt2img);
appendContextMenuOption('#txt2img_interrupt', 'Generate forever', generateOnRepeat_txt2img);
appendContextMenuOption('#img2img_generate', 'Generate forever', generateOnRepeat_img2img);
appendContextMenuOption('#img2img_interrupt', 'Generate forever', generateOnRepeat_img2img);
let cancelGenerateForever = function() { let cancelGenerateForever = function() {
clearInterval(window.generateOnRepeatInterval); clearInterval(window.generateOnRepeatInterval);
...@@ -167,6 +173,4 @@ var addContextMenuEventListener = initResponse[2]; ...@@ -167,6 +173,4 @@ var addContextMenuEventListener = initResponse[2];
})(); })();
//End example Context Menu Items //End example Context Menu Items
onUiUpdate(function() {
addContextMenuEventListener();
});
onAfterUiUpdate(addContextMenuEventListener);
...@@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) { ...@@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) {
} }
} }
function eventHasFiles(e) {
if (!e.dataTransfer || !e.dataTransfer.files) return false;
if (e.dataTransfer.files.length > 0) return true;
if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true;
return false;
}
function dragDropTargetIsPrompt(target) {
if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true;
if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true;
return false;
}
window.document.addEventListener('dragover', e => { window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]');
if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
return;
}
if (!eventHasFiles(e)) return;
var targetImage = target.closest('[data-testid="image"]');
if (!dragDropTargetIsPrompt(target) && !targetImage) return;
e.stopPropagation(); e.stopPropagation();
e.preventDefault(); e.preventDefault();
e.dataTransfer.dropEffect = 'copy'; e.dataTransfer.dropEffect = 'copy';
...@@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => { ...@@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => { window.document.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
if (target.placeholder.indexOf("Prompt") == -1) {
return;
if (!eventHasFiles(e)) return;
if (dragDropTargetIsPrompt(target)) {
e.stopPropagation();
e.preventDefault();
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]');
if (fileInput) {
fileInput.files = files;
fileInput.dispatchEvent(new Event('change'));
}
} }
const imgWrap = target.closest('[data-testid="image"]');
if (!imgWrap) {
var targetImage = target.closest('[data-testid="image"]');
if (targetImage) {
e.stopPropagation();
e.preventDefault();
const files = e.dataTransfer.files;
dropReplaceImage(targetImage, files);
return; return;
} }
e.stopPropagation();
e.preventDefault();
const files = e.dataTransfer.files;
dropReplaceImage(imgWrap, files);
}); });
window.addEventListener('paste', e => { window.addEventListener('paste', e => {
......
...@@ -100,11 +100,12 @@ function keyupEditAttention(event) { ...@@ -100,11 +100,12 @@ function keyupEditAttention(event) {
if (String(weight).length == 1) weight += ".0"; if (String(weight).length == 1) weight += ".0";
if (closeCharacter == ')' && weight == 1) { if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
var endParenPos = text.substring(selectionEnd).indexOf(')');
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
selectionStart--; selectionStart--;
selectionEnd--; selectionEnd--;
} else { } else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
} }
target.focus(); target.focus();
......
...@@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type) ...@@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
} }
return [confirmed, config_state_name, config_restore_type]; return [confirmed, config_state_name, config_restore_type];
} }
function toggle_all_extensions(event) {
gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
checkbox_el.checked = event.target.checked;
});
}
function toggle_extension() {
let all_extensions_toggled = true;
for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
if (!checkbox_el.checked) {
all_extensions_toggled = false;
break;
}
}
gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
}
...@@ -3,10 +3,17 @@ function setupExtraNetworksForTab(tabname) { ...@@ -3,10 +3,17 @@ function setupExtraNetworksForTab(tabname) {
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
var sort = gradioApp().getElementById(tabname + '_extra_sort');
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
search.classList.add('search'); search.classList.add('search');
sort.classList.add('sort');
sortOrder.classList.add('sortorder');
sort.dataset.sortkey = 'sortDefault';
tabs.appendChild(search); tabs.appendChild(search);
tabs.appendChild(sort);
tabs.appendChild(sortOrder);
tabs.appendChild(refresh); tabs.appendChild(refresh);
var applyFilter = function() { var applyFilter = function() {
...@@ -26,8 +33,51 @@ function setupExtraNetworksForTab(tabname) { ...@@ -26,8 +33,51 @@ function setupExtraNetworksForTab(tabname) {
}); });
}; };
var applySort = function() {
var reverse = sortOrder.classList.contains("sortReverse");
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
return;
}
sort.dataset.sortkey = sortKeyStore;
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
cards.forEach(function(card) {
card.originalParentElement = card.parentElement;
});
var sortedCards = Array.from(cards);
sortedCards.sort(function(cardA, cardB) {
var a = cardA.dataset[sortKey];
var b = cardB.dataset[sortKey];
if (!isNaN(a) && !isNaN(b)) {
return parseInt(a) - parseInt(b);
}
return (a < b ? -1 : (a > b ? 1 : 0));
});
if (reverse) {
sortedCards.reverse();
}
cards.forEach(function(card) {
card.remove();
});
sortedCards.forEach(function(card) {
card.originalParentElement.appendChild(card);
});
};
search.addEventListener("input", applyFilter); search.addEventListener("input", applyFilter);
applyFilter(); applyFilter();
["change", "blur", "click"].forEach(function(evt) {
sort.querySelector("input").addEventListener(evt, applySort);
});
sortOrder.addEventListener("click", function() {
sortOrder.classList.toggle("sortReverse");
applySort();
});
extraNetworksApplyFilter[tabname] = applyFilter; extraNetworksApplyFilter[tabname] = applyFilter;
} }
......
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined; let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function() { onAfterUiUpdate(function() {
if (!txt2img_gallery) { if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img"); txt2img_gallery = attachGalleryListeners("txt2img");
} }
......
...@@ -15,7 +15,7 @@ var titles = { ...@@ -15,7 +15,7 @@ var titles = {
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results", "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomized",
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style", "\u{1f4be}": "Save style",
...@@ -112,21 +112,29 @@ var titles = { ...@@ -112,21 +112,29 @@ var titles = {
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.", "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
}; };
function updateTooltipForSpan(span) { function updateTooltip(element) {
if (span.title) return; // already has a title if (element.title) return; // already has a title
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
let text = element.textContent;
let tooltip = localization[titles[text]] || titles[text];
if (!tooltip) { if (!tooltip) {
tooltip = localization[titles[span.value]] || titles[span.value];
let value = element.value;
if (value) tooltip = localization[titles[value]] || titles[value];
} }
if (!tooltip) { if (!tooltip) {
// Gradio dropdown options have `data-value`.
let dataValue = element.dataset.value;
if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
}
if (!tooltip) {
for (const c of span.classList) { for (const c of element.classList) {
if (c in titles) { if (c in titles) {
tooltip = localization[titles[c]] || titles[c]; tooltip = localization[titles[c]] || titles[c];
break; break;
...@@ -135,34 +143,53 @@ function updateTooltipForSpan(span) { ...@@ -135,34 +143,53 @@ function updateTooltipForSpan(span) {
} }
if (tooltip) { if (tooltip) {
span.title = tooltip; element.title = tooltip;
} }
} }
function updateTooltipForSelect(select) {
if (select.onchange != null) return;
select.onchange = function() {
select.title = localization[titles[select.value]] || titles[select.value] || "";
};
}
var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
onUiUpdate(function(m) {
m.forEach(function(record) {
record.addedNodes.forEach(function(node) {
if (observedTooltipElements[node.tagName]) {
updateTooltipForSpan(node);
}
if (node.tagName == "SELECT") {
updateTooltipForSelect(node);
}
if (node.querySelectorAll) {
node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
node.querySelectorAll('select').forEach(updateTooltipForSelect);
}
});
});
});
// Nodes to check for adding tooltips.
const tooltipCheckNodes = new Set();
// Timer for debouncing tooltip check.
let tooltipCheckTimer = null;
function processTooltipCheckNodes() {
for (const node of tooltipCheckNodes) {
updateTooltip(node);
}
tooltipCheckNodes.clear();
}
onUiUpdate(function(mutationRecords) {
for (const record of mutationRecords) {
if (record.type === "childList" && record.target.classList.contains("options")) {
// This smells like a Gradio dropdown menu having changed,
// so let's enqueue an update for the input element that shows the current value.
let wrap = record.target.parentNode;
let input = wrap?.querySelector("input");
if (input) {
input.title = ""; // So we'll even have a chance to update it.
tooltipCheckNodes.add(input);
}
}
for (const node of record.addedNodes) {
if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
if (!node.title) {
if (
node.tagName === "SPAN" ||
node.tagName === "BUTTON" ||
node.tagName === "P" ||
node.tagName === "INPUT" ||
(node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
) {
tooltipCheckNodes.add(node);
}
}
node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
}
}
}
if (tooltipCheckNodes.size) {
clearTimeout(tooltipCheckTimer);
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
}
});
...@@ -39,5 +39,5 @@ function imageMaskResize() { ...@@ -39,5 +39,5 @@ function imageMaskResize() {
}); });
} }
onUiUpdate(imageMaskResize); onAfterUiUpdate(imageMaskResize);
window.addEventListener('resize', imageMaskResize); window.addEventListener('resize', imageMaskResize);
window.onload = (function() {
window.addEventListener('drop', e => {
const target = e.composedPath()[0];
if (target.placeholder.indexOf("Prompt") == -1) return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
e.stopPropagation();
e.preventDefault();
const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]');
if (fileInput) {
fileInput.files = files;
fileInput.dispatchEvent(new Event('change'));
}
});
});
...@@ -170,7 +170,7 @@ function modalTileImageToggle(event) { ...@@ -170,7 +170,7 @@ function modalTileImageToggle(event) {
event.stopPropagation(); event.stopPropagation();
} }
onUiUpdate(function() { onAfterUiUpdate(function() {
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img'); var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
if (fullImg_preview != null) { if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox); fullImg_preview.forEach(setupImageForLightbox);
......
let gamepads = [];
window.addEventListener('gamepadconnected', (e) => { window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index; const index = e.gamepad.index;
let isWaiting = false; let isWaiting = false;
setInterval(async() => { gamepads[index] = setInterval(async() => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return; if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index]; const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0]; const xValue = gamepad.axes[0];
...@@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => { ...@@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
}, 10); }, 10);
}); });
window.addEventListener('gamepaddisconnected', (e) => {
clearInterval(gamepads[e.gamepad.index]);
});
/* /*
Primarily for vr controller type pointer devices. Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr. I use the wheel event because there's currently no way to do it properly with web xr.
......
...@@ -4,7 +4,7 @@ let lastHeadImg = null; ...@@ -4,7 +4,7 @@ let lastHeadImg = null;
let notificationButton = null; let notificationButton = null;
onUiUpdate(function() { onAfterUiUpdate(function() {
if (notificationButton == null) { if (notificationButton == null) {
notificationButton = gradioApp().getElementById('request_notifications'); notificationButton = gradioApp().getElementById('request_notifications');
......
function createRow(table, cellName, items) {
var tr = document.createElement('tr');
var res = [];
items.forEach(function(x, i) {
if (x === undefined) {
res.push(null);
return;
}
var td = document.createElement(cellName);
td.textContent = x;
tr.appendChild(td);
res.push(td);
var colspan = 1;
for (var n = i + 1; n < items.length; n++) {
if (items[n] !== undefined) {
break;
}
colspan += 1;
}
if (colspan > 1) {
td.colSpan = colspan;
}
});
table.appendChild(tr);
return res;
}
function showProfile(path, cutoff = 0.05) {
requestGet(path, {}, function(data) {
var table = document.createElement('table');
table.className = 'popup-table';
data.records['total'] = data.total;
var keys = Object.keys(data.records).sort(function(a, b) {
return data.records[b] - data.records[a];
});
var items = keys.map(function(x) {
return {key: x, parts: x.split('/'), time: data.records[x]};
});
var maxLength = items.reduce(function(a, b) {
return Math.max(a, b.parts.length);
}, 0);
var cols = createRow(table, 'th', ['record', 'seconds']);
cols[0].colSpan = maxLength;
function arraysEqual(a, b) {
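// JS arrays compare via their string forms, so neither a < b nor b < a
// holds exactly when the two paths stringify identically.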
return !(a < b || b < a);
}
var addLevel = function(level, parent, hide) {
var matching = items.filter(function(x) {
return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
});
var sorted = matching.sort(function(a, b) {
return b.time - a.time;
});
var othersTime = 0;
var othersList = [];
var othersRows = [];
var childrenRows = [];
sorted.forEach(function(x) {
var visible = x.time >= cutoff && !hide;
var cells = [];
for (var i = 0; i < maxLength; i++) {
cells.push(x.parts[i]);
}
cells.push(x.time.toFixed(3));
var cols = createRow(table, 'td', cells);
for (i = 0; i < level; i++) {
cols[i].className = 'muted';
}
var tr = cols[0].parentNode;
if (!visible) {
tr.classList.add("hidden");
}
if (x.time >= cutoff) {
childrenRows.push(tr);
} else {
othersTime += x.time;
othersList.push(x.parts[level]);
othersRows.push(tr);
}
var children = addLevel(level + 1, parent.concat([x.parts[level]]), true);
if (children.length > 0) {
var cell = cols[level];
var onclick = function() {
cell.classList.remove("link");
cell.removeEventListener("click", onclick);
children.forEach(function(x) {
x.classList.remove("hidden");
});
};
cell.classList.add("link");
cell.addEventListener("click", onclick);
}
});
if (othersTime > 0) {
var cells = [];
for (var i = 0; i < maxLength; i++) {
cells.push(parent[i]);
}
cells.push(othersTime.toFixed(3));
cells[level] = 'others';
var cols = createRow(table, 'td', cells);
for (i = 0; i < level; i++) {
cols[i].className = 'muted';
}
var cell = cols[level];
var tr = cell.parentNode;
var onclick = function() {
tr.classList.add("hidden");
cell.classList.remove("link");
cell.removeEventListener("click", onclick);
othersRows.forEach(function(x) {
x.classList.remove("hidden");
});
};
cell.title = othersList.join(", ");
cell.classList.add("link");
cell.addEventListener("click", onclick);
if (hide) {
tr.classList.add("hidden");
}
childrenRows.push(tr);
}
return childrenRows;
};
addLevel(0, []);
popup(table);
});
}
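showProfile() consumes the JSON served at the path wired into the footer above. A hedged Python sketch that prints the same data; the response shape (a float "total" plus a "records" mapping of slash-separated names to seconds) is inferred from the JS, so treat it as an assumption:

import json
import urllib.request

def print_startup_profile(base_url="http://127.0.0.1:7860", cutoff=0.05):
    with urllib.request.urlopen(base_url + "/internal/profile-startup") as resp:
        data = json.load(resp)
    print(f"total: {data['total']:.3f}s")
    for record, seconds in sorted(data["records"].items(), key=lambda kv: -kv[1]):
        if seconds >= cutoff:  # same default cutoff as showProfile()
            print(f"{seconds:8.3f}  {record}")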
let promptTokenCountDebounceTime = 800;
let promptTokenCountTimeouts = {};
var promptTokenCountUpdateFunctions = {};
function update_txt2img_tokens(...args) {
// Called from Gradio
update_token_counter("txt2img_token_button");
if (args.length == 2) {
return args[0];
}
return args;
}
function update_img2img_tokens(...args) {
// Called from Gradio
update_token_counter("img2img_token_button");
if (args.length == 2) {
return args[0];
}
return args;
}
function update_token_counter(button_id) {
if (opts.disable_token_counters) {
return;
}
if (promptTokenCountTimeouts[button_id]) {
clearTimeout(promptTokenCountTimeouts[button_id]);
}
promptTokenCountTimeouts[button_id] = setTimeout(
() => gradioApp().getElementById(button_id)?.click(),
promptTokenCountDebounceTime,
);
}
function recalculatePromptTokens(name) {
promptTokenCountUpdateFunctions[name]?.();
}
function recalculate_prompts_txt2img() {
// Called from Gradio
recalculatePromptTokens('txt2img_prompt');
recalculatePromptTokens('txt2img_neg_prompt');
return Array.from(arguments);
}
function recalculate_prompts_img2img() {
// Called from Gradio
recalculatePromptTokens('img2img_prompt');
recalculatePromptTokens('img2img_neg_prompt');
return Array.from(arguments);
}
function setupTokenCounting(id, id_counter, id_button) {
var prompt = gradioApp().getElementById(id);
var counter = gradioApp().getElementById(id_counter);
var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
if (opts.disable_token_counters) {
counter.style.display = "none";
return;
}
if (counter.parentElement == prompt.parentElement) {
return;
}
prompt.parentElement.insertBefore(counter, prompt);
prompt.parentElement.style.position = "relative";
promptTokenCountUpdateFunctions[id] = function() {
update_token_counter(id_button);
};
textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
}
function setupTokenCounters() {
setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
}
...@@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) { ...@@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) {
} }
var promptTokecountUpdateFuncs = {};
function recalculatePromptTokens(name) {
if (promptTokecountUpdateFuncs[name]) {
promptTokecountUpdateFuncs[name]();
}
}
function recalculate_prompts_txt2img() {
recalculatePromptTokens('txt2img_prompt');
recalculatePromptTokens('txt2img_neg_prompt');
return Array.from(arguments);
}
function recalculate_prompts_img2img() {
recalculatePromptTokens('img2img_prompt');
recalculatePromptTokens('img2img_neg_prompt');
return Array.from(arguments);
}
var opts = {}; var opts = {};
onUiUpdate(function() { onAfterUiUpdate(function() {
if (Object.keys(opts).length != 0) return; if (Object.keys(opts).length != 0) return;
var json_elem = gradioApp().getElementById('settings_json'); var json_elem = gradioApp().getElementById('settings_json');
...@@ -302,28 +281,7 @@ onUiUpdate(function() { ...@@ -302,28 +281,7 @@ onUiUpdate(function() {
json_elem.parentElement.style.display = "none"; json_elem.parentElement.style.display = "none";
function registerTextarea(id, id_counter, id_button) { setupTokenCounters();
var prompt = gradioApp().getElementById(id);
var counter = gradioApp().getElementById(id_counter);
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
if (counter.parentElement == prompt.parentElement) {
return;
}
prompt.parentElement.insertBefore(counter, prompt);
prompt.parentElement.style.position = "relative";
promptTokecountUpdateFuncs[id] = function() {
update_token_counter(id_button);
};
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
}
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
var show_all_pages = gradioApp().getElementById('settings_show_all_pages'); var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
var settings_tabs = gradioApp().querySelector('#settings div'); var settings_tabs = gradioApp().querySelector('#settings div');
...@@ -354,33 +312,6 @@ onOptionsChanged(function() { ...@@ -354,33 +312,6 @@ onOptionsChanged(function() {
}); });
let txt2img_textarea, img2img_textarea = undefined; let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800;
let token_timeouts = {};
function update_txt2img_tokens(...args) {
update_token_counter("txt2img_token_button");
if (args.length == 2) {
return args[0];
}
return args;
}
function update_img2img_tokens(...args) {
update_token_counter(
"img2img_token_button"
);
if (args.length == 2) {
return args[0];
}
return args;
}
function update_token_counter(button_id) {
if (token_timeouts[button_id]) {
clearTimeout(token_timeouts[button_id]);
}
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
function restart_reload() { function restart_reload() {
document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
......
...@@ -42,7 +42,7 @@ onOptionsChanged(function() { ...@@ -42,7 +42,7 @@ onOptionsChanged(function() {
function settingsHintsShowQuicksettings() { function settingsHintsShowQuicksettings() {
requestGet("./internal/quicksettings-hint", {}, function(data) { requestGet("./internal/quicksettings-hint", {}, function(data) {
var table = document.createElement('table'); var table = document.createElement('table');
table.className = 'settings-value-table'; table.className = 'popup-table';
data.forEach(function(obj) { data.forEach(function(obj) {
var tr = document.createElement('tr'); var tr = document.createElement('tr');
......
...@@ -241,6 +241,9 @@ class UpscalerItem(BaseModel): ...@@ -241,6 +241,9 @@ class UpscalerItem(BaseModel):
model_url: Optional[str] = Field(title="URL") model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale") scale: Optional[float] = Field(title="Scale")
class LatentUpscalerModeItem(BaseModel):
name: str = Field(title="Name")
class SDModelItem(BaseModel): class SDModelItem(BaseModel):
title: str = Field(title="Title") title: str = Field(title="Title")
model_name: str = Field(title="Model Name") model_name: str = Field(title="Model Name")
...@@ -249,6 +252,10 @@ class SDModelItem(BaseModel): ...@@ -249,6 +252,10 @@ class SDModelItem(BaseModel):
filename: str = Field(title="Filename") filename: str = Field(title="Filename")
config: Optional[str] = Field(title="Config file") config: Optional[str] = Field(title="Config file")
class SDVaeItem(BaseModel):
model_name: str = Field(title="Model Name")
filename: str = Field(title="Filename")
class HypernetworkItem(BaseModel): class HypernetworkItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
path: Optional[str] = Field(title="Path") path: Optional[str] = Field(title="Path")
...@@ -267,10 +274,6 @@ class PromptStyleItem(BaseModel): ...@@ -267,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt") prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")
class EmbeddingItem(BaseModel): class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
......
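As a hedged illustration of the new API schema items above, here is a self-contained pydantic v1 sketch; the field values are assumptions, not taken from the commit:

from pydantic import BaseModel, Field

class SDVaeItem(BaseModel):  # mirrors the model added above
    model_name: str = Field(title="Model Name")
    filename: str = Field(title="Filename")

# hypothetical values; an API route listing VAEs would serialize a list of such objects
item = SDVaeItem(model_name="vae-ft-mse-840000", filename="models/VAE/vae.safetensors")
print(item.json())  # {"model_name": "vae-ft-mse-840000", "filename": "models/VAE/vae.safetensors"}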
from functools import wraps
import html import html
import sys
import threading import threading
import traceback
import time import time
from modules import shared, progress from modules import shared, progress, errors
queue_lock = threading.Lock() queue_lock = threading.Lock()
...@@ -20,17 +19,18 @@ def wrap_queued_call(func): ...@@ -20,17 +19,18 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
@wraps(func)
def f(*args, **kwargs): def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id # if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
id_task = args[0] id_task = args[0]
progress.add_task_to_queue(id_task) progress.add_task_to_queue(id_task)
else: else:
id_task = None id_task = None
with queue_lock: with queue_lock:
shared.state.begin() shared.state.begin(job=id_task)
progress.start_task(id_task) progress.start_task(id_task)
try: try:
...@@ -47,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): ...@@ -47,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
@wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon: if run_memmon:
...@@ -56,16 +57,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): ...@@ -56,16 +57,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
try: try:
res = list(func(*args, **kwargs)) res = list(func(*args, **kwargs))
except Exception as e: except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text # When printing out our debug argument list,
max_debug_str_len = 131072 # (1024*1024)/8 # do not print out more than 128 KB of text
max_debug_str_len = 131072
print("Error completing request", file=sys.stderr) message = "Error completing request"
argStr = f"Arguments: {args} {kwargs}" arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
print(argStr[:max_debug_str_len], file=sys.stderr) if len(arg_str) > max_debug_str_len:
if len(argStr) > max_debug_str_len: arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) errors.report(f"{message}\n{arg_str}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
...@@ -108,4 +107,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): ...@@ -108,4 +107,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res) return tuple(res)
return f return f
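The "task(...)" job-id convention checked at the top of wrap_gradio_gpu_call can be isolated into a minimal sketch; looks_like_task_id is a hypothetical helper, not a function in the codebase:

def looks_like_task_id(value):
    # mirrors the check above: a string shaped like "task(...)" is treated as a job id
    return isinstance(value, str) and value.startswith("task(") and value.endswith(")")

assert looks_like_task_id("task(txt2img-some-id)")          # hypothetical id
assert not looks_like_task_id("a prompt mentioning task(")  # not a job id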
...@@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la ...@@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly") parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed") parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup") parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
...@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch ...@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
import os import os
import sys
import traceback
import cv2 import cv2
import torch import torch
import modules.face_restoration import modules.face_restoration
import modules.shared import modules.shared
from modules import shared, devices, modelloader from modules import shared, devices, modelloader, errors
from modules.paths import models_path from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes # codeformer people made a choice to include modified basicsr library to their project which makes
...@@ -17,14 +15,11 @@ model_dir = "Codeformer" ...@@ -17,14 +15,11 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir) model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None codeformer = None
def setup_model(dirname): def setup_model(dirname):
global model_path os.makedirs(model_path, exist_ok=True)
if not os.path.exists(model_path):
os.makedirs(model_path)
path = modules.paths.paths.get("CodeFormer", None) path = modules.paths.paths.get("CodeFormer", None)
if path is None: if path is None:
...@@ -105,8 +100,8 @@ def setup_model(dirname): ...@@ -105,8 +100,8 @@ def setup_model(dirname):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output del output
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception as error: except Exception:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8') restored_face = restored_face.astype('uint8')
...@@ -127,15 +122,11 @@ def setup_model(dirname): ...@@ -127,15 +122,11 @@ def setup_model(dirname):
return restored_img return restored_img
global have_codeformer
have_codeformer = True
global codeformer global codeformer
codeformer = FaceRestorerCodeFormer(dirname) codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer) shared.face_restorers.append(codeformer)
except Exception: except Exception:
print("Error setting up CodeFormer:", file=sys.stderr) errors.report("Error setting up CodeFormer", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
# sys.path = stored_sys_path # sys.path = stored_sys_path
...@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c ...@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
""" """
import os import os
import sys
import traceback
import json import json
import time import time
import tqdm import tqdm
...@@ -13,7 +11,7 @@ from datetime import datetime ...@@ -13,7 +11,7 @@ from datetime import datetime
from collections import OrderedDict from collections import OrderedDict
import git import git
from modules import shared, extensions from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir from modules.paths_internal import script_path, config_states_dir
...@@ -53,8 +51,7 @@ def get_webui_config(): ...@@ -53,8 +51,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")): if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path) webui_repo = git.Repo(script_path)
except Exception: except Exception:
print(f"Error reading webui git info from {script_path}:", file=sys.stderr) errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
webui_remote = None webui_remote = None
webui_commit_hash = None webui_commit_hash = None
...@@ -134,8 +131,7 @@ def restore_webui_config(config): ...@@ -134,8 +131,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")): if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path) webui_repo = git.Repo(script_path)
except Exception: except Exception:
print(f"Error reading webui git info from {script_path}:", file=sys.stderr) errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
return return
try: try:
...@@ -143,8 +139,7 @@ def restore_webui_config(config): ...@@ -143,8 +139,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True) webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.") print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception: except Exception:
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) errors.report(f"Error restoring webui to commit{webui_commit_hash}")
print(traceback.format_exc(), file=sys.stderr)
def restore_extension_config(config): def restore_extension_config(config):
......
import sys import sys
import contextlib import contextlib
from functools import lru_cache
import torch import torch
from modules import errors from modules import errors
...@@ -13,13 +15,6 @@ def has_mps() -> bool: ...@@ -13,13 +15,6 @@ def has_mps() -> bool:
else: else:
return mac_specific.has_mps return mac_specific.has_mps
def extract_device_id(args, name):
for x in range(len(args)):
if name in args[x]:
return args[x + 1]
return None
def get_cuda_device_string(): def get_cuda_device_string():
from modules import shared from modules import shared
...@@ -154,3 +149,19 @@ def test_for_nans(x, where): ...@@ -154,3 +149,19 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check." message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message) raise NansException(message)
@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
spends about 2.7 seconds doing that, at least with NVidia.
"""
x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
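A hedged usage sketch of the warmup above: callers would invoke it once at startup, and thanks to @lru_cache any later call is a no-op:

from modules import devices  # assumes the module as patched above

devices.first_time_calculation()  # pays the ~700 MB / ~2.7 s CUDA init cost up front
devices.first_time_calculation()  # cached; returns immediately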
import sys import sys
import textwrap
import traceback import traceback
exception_records = []
def record_exception():
_, e, tb = sys.exc_info()
if e is None:
return
if exception_records and exception_records[-1] == e:
return
exception_records.append((e, tb))
if len(exception_records) > 5:
exception_records.pop(0)
def report(message: str, *, exc_info: bool = False) -> None:
"""
Print an error message to stderr, with optional traceback.
"""
record_exception()
for line in message.splitlines():
print("***", line, file=sys.stderr)
if exc_info:
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
print("---", file=sys.stderr)
def print_error_explanation(message): def print_error_explanation(message):
record_exception()
lines = message.strip().split("\n") lines = message.strip().split("\n")
max_len = max([len(x) for x in lines]) max_len = max([len(x) for x in lines])
...@@ -12,9 +46,15 @@ def print_error_explanation(message): ...@@ -12,9 +46,15 @@ def print_error_explanation(message):
print('=' * max_len, file=sys.stderr) print('=' * max_len, file=sys.stderr)
def display(e: Exception, task): def display(e: Exception, task, *, full_traceback=False):
record_exception()
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) te = traceback.TracebackException.from_exception(e)
if full_traceback:
# include frames leading up to the try/except block
te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
print(*te.format(), sep="", file=sys.stderr)
message = str(e) message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
...@@ -28,6 +68,8 @@ already_displayed = {} ...@@ -28,6 +68,8 @@ already_displayed = {}
def display_once(e: Exception, task): def display_once(e: Exception, task):
record_exception()
if task in already_displayed: if task in already_displayed:
return return
......
import os import sys
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict): def mod2normal(state_dict):
...@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler): ...@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data) scalers.append(scaler_data)
for file in model_paths: for file in model_paths:
if "http" in file: if file.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
...@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler): ...@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data) self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model): def do_upscale(self, img, selected_model):
model = self.load_model(selected_model) try:
if model is None: model = self.load_model(selected_model)
except Exception as e:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img return img
model.to(devices.device_esrgan) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
def load_model(self, path: str): def load_model(self, path: str):
if "http" in path: if path.startswith("http"):
filename = load_file_from_url( # TODO: this doesn't use `path` at all?
filename = modelloader.load_file_from_url(
url=self.model_url, url=self.model_url,
model_dir=self.model_download_path, model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth", file_name=f"{self.model_name}.pth",
progress=True,
) )
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None:
print(f"Unable to load {self.model_path} from {filename}")
return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
......
import os import os
import sys
import threading import threading
import traceback
import git from modules import shared, errors
from modules.gitpython_hack import Repo
from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = [] extensions = []
if not os.path.exists(extensions_dir): os.makedirs(extensions_dir, exist_ok=True)
os.makedirs(extensions_dir)
def active(): def active():
...@@ -54,10 +50,9 @@ class Extension: ...@@ -54,10 +50,9 @@ class Extension:
repo = None repo = None
try: try:
if os.path.exists(os.path.join(self.path, ".git")): if os.path.exists(os.path.join(self.path, ".git")):
repo = git.Repo(self.path) repo = Repo(self.path)
except Exception: except Exception:
print(f"Error reading github repository info from {self.path}:", file=sys.stderr) errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare: if repo is None or repo.bare:
self.remote = None self.remote = None
...@@ -72,8 +67,8 @@ class Extension: ...@@ -72,8 +67,8 @@ class Extension:
self.commit_hash = commit.hexsha self.commit_hash = commit.hexsha
self.version = self.commit_hash[:8] self.version = self.commit_hash[:8]
except Exception as ex: except Exception:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None self.remote = None
self.have_info_from_repo = True self.have_info_from_repo = True
...@@ -94,7 +89,7 @@ class Extension: ...@@ -94,7 +89,7 @@ class Extension:
return res return res
def check_updates(self): def check_updates(self):
repo = git.Repo(self.path) repo = Repo(self.path)
for fetch in repo.remote().fetch(dry_run=True): for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE: if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True self.can_update = True
...@@ -116,7 +111,7 @@ class Extension: ...@@ -116,7 +111,7 @@ class Extension:
self.status = "latest" self.status = "latest"
def fetch_and_reset_hard(self, commit='origin'): def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path) repo = Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`, # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error. # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True) repo.git.fetch(all=True)
......
...@@ -32,6 +32,9 @@ class ExtraNetworkParams: ...@@ -32,6 +32,9 @@ class ExtraNetworkParams:
else: else:
self.positional.append(item) self.positional.append(item)
def __eq__(self, other):
return self.items == other.items
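A hedged sketch of what the new __eq__ enables (values hypothetical): two parameter sets parsed from identical prompt text now compare equal, so callers can deduplicate or detect unchanged networks:

from modules.extra_networks import ExtraNetworkParams

a = ExtraNetworkParams(items=["some-hypernet", "0.8"])
b = ExtraNetworkParams(items=["some-hypernet", "0.8"])
assert a == b  # compares .items, per the __eq__ above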
class ExtraNetwork: class ExtraNetwork:
def __init__(self, name): def __init__(self, name):
...@@ -100,6 +103,9 @@ def activate(p, extra_network_data): ...@@ -100,6 +103,9 @@ def activate(p, extra_network_data):
except Exception as e: except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}") errors.display(e, f"activating extra network {extra_network_name}")
if p.scripts is not None:
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
def deactivate(p, extra_network_data): def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call """call deactivate for extra networks in extra_network_data in specified order, then call
......
...@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): ...@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list): def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork additional = shared.opts.sd_hypernetwork
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
...@@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): ...@@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
names = [] names = []
multipliers = [] multipliers = []
for params in params_list: for params in params_list:
assert len(params.items) > 0 assert params.items
names.append(params.items[0]) names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
......
...@@ -73,8 +73,7 @@ def to_half(tensor, enable): ...@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin() shared.state.begin(job="model-merge")
shared.state.job = 'model-merge'
def fail(message): def fail(message):
shared.state.textinfo = message shared.state.textinfo = message
......
...@@ -55,7 +55,7 @@ def image_from_url_text(filedata): ...@@ -55,7 +55,7 @@ def image_from_url_text(filedata):
if filedata is None: if filedata is None:
return None return None
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0] filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False): if type(filedata) == dict and filedata.get("is_file", False):
...@@ -174,31 +174,6 @@ def send_image_and_dimensions(x): ...@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h return img, w, h
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
If the infotext has no hash, then a hypernet with the same name will be selected instead.
"""
hypernet_name = hypernet_name.lower()
if hypernet_hash is not None:
# Try to match the hash in the name
for hypernet_key in shared.hypernetworks.keys():
result = re_hypernet_hash.search(hypernet_key)
if result is not None and result[1] == hypernet_hash:
return hypernet_key
else:
# Fall back to a hypernet with the same name
for hypernet_key in shared.hypernetworks.keys():
if hypernet_key.lower().startswith(hypernet_name):
return hypernet_key
return None
def restore_old_hires_fix_params(res): def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into """for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale""" width, height, and hr scale"""
...@@ -265,19 +240,30 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -265,19 +240,30 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else: else:
prompt += ("" if prompt == "" else "\n") + line prompt += ("" if prompt == "" else "\n") + line
if shared.opts.infotext_styles != "Ignore":
found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
if shared.opts.infotext_styles == "Apply":
res["Styles array"] = found_styles
elif shared.opts.infotext_styles == "Apply if any" and found_styles:
res["Styles array"] = found_styles
res["Prompt"] = prompt res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline): for k, v in re_param.findall(lastline):
if v[0] == '"' and v[-1] == '"': try:
v = unquote(v) if v[0] == '"' and v[-1] == '"':
v = unquote(v)
m = re_imagesize.match(v)
if m is not None: m = re_imagesize.match(v)
res[f"{k}-1"] = m.group(1) if m is not None:
res[f"{k}-2"] = m.group(2) res[f"{k}-1"] = m.group(1)
else: res[f"{k}-2"] = m.group(2)
res[k] = v else:
res[k] = v
except Exception:
print(f"Error parsing \"{k}: {v}\"")
# Missing CLIP skip means it was set to 1 (the default) # Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res: if "Clip skip" not in res:
...@@ -306,11 +292,19 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -306,11 +292,19 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "RNG" not in res: if "RNG" not in res:
res["RNG"] = "GPU" res["RNG"] = "GPU"
return res if "Schedule type" not in res:
res["Schedule type"] = "Automatic"
if "Schedule max sigma" not in res:
res["Schedule max sigma"] = 0
settings_map = {} if "Schedule min sigma" not in res:
res["Schedule min sigma"] = 0
if "Schedule rho" not in res:
res["Schedule rho"] = 0
return res
infotext_to_setting_name_mapping = [ infotext_to_setting_name_mapping = [
...@@ -318,6 +312,10 @@ infotext_to_setting_name_mapping = [ ...@@ -318,6 +312,10 @@ infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'), ('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'), ('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'), ('ENSD', 'eta_noise_seed_delta'),
('Schedule type', 'k_sched_type'),
('Schedule max sigma', 'sigma_max'),
('Schedule min sigma', 'sigma_min'),
('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'), ('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'), ('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'), ('Eta DDIM', 'eta_ddim'),
...@@ -330,6 +328,7 @@ infotext_to_setting_name_mapping = [ ...@@ -330,6 +328,7 @@ infotext_to_setting_name_mapping = [
('Token merging ratio hr', 'token_merging_ratio_hr'), ('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'), ('RNG', 'randn_source'),
('NGMS', 's_min_uncond'), ('NGMS', 's_min_uncond'),
('Pad conds', 'pad_cond_uncond'),
] ]
...@@ -421,7 +420,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, ...@@ -421,7 +420,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
vals_pairs = [f"{k}: {v}" for k, v in vals.items()] vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
paste_fields = paste_fields + [(override_settings_component, paste_settings)] paste_fields = paste_fields + [(override_settings_component, paste_settings)]
...@@ -438,5 +437,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, ...@@ -438,5 +437,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
outputs=[], outputs=[],
show_progress=False, show_progress=False,
) )
import os import os
import sys
import traceback
import facexlib import facexlib
import gfpgan import gfpgan
import modules.face_restoration import modules.face_restoration
from modules import paths, shared, devices, modelloader from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN" model_dir = "GFPGAN"
user_path = None user_path = None
...@@ -27,7 +25,7 @@ def gfpgann(): ...@@ -27,7 +25,7 @@ def gfpgann():
return None return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) == 1 and "http" in models[0]: if len(models) == 1 and models[0].startswith("http"):
model_file = models[0] model_file = models[0]
elif len(models) != 0: elif len(models) != 0:
latest_file = max(models, key=os.path.getctime) latest_file = max(models, key=os.path.getctime)
...@@ -72,11 +70,8 @@ gfpgan_constructor = None ...@@ -72,11 +70,8 @@ gfpgan_constructor = None
def setup_model(dirname): def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
try: try:
os.makedirs(model_path, exist_ok=True)
from gfpgan import GFPGANer from gfpgan import GFPGANer
from facexlib import detection, parsing # noqa: F401 from facexlib import detection, parsing # noqa: F401
global user_path global user_path
...@@ -112,5 +107,4 @@ def setup_model(dirname): ...@@ -112,5 +107,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN()) shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception: except Exception:
print("Error setting up GFPGAN:", file=sys.stderr) errors.report("Error setting up GFPGAN", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
from __future__ import annotations
import io
import subprocess
import git
class Git(git.Git):
"""
Git subclassed to never use persistent processes.
"""
def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
ret = subprocess.check_output(
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
input=self._prepare_ref(ref),
cwd=self._working_dir,
timeout=2,
)
return self._parse_object_header(ret)
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
# Not really streaming, per se; this buffers the entire object in memory.
# Shouldn't be a problem for our use case, since we're only using this for
# object headers (commit objects).
ret = subprocess.check_output(
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
input=self._prepare_ref(ref),
cwd=self._working_dir,
timeout=30,
)
bio = io.BytesIO(ret)
hexsha, typename, size = self._parse_object_header(bio.readline())
return (hexsha, typename, size, self.CatFileContentStream(size, bio))
class Repo(git.Repo):
GitCommandWrapperType = Git
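A hedged usage sketch (repository path hypothetical): Repo is intended as a drop-in replacement for git.Repo, except that object lookups that would normally keep a persistent git cat-file process alive go through one-shot subprocess calls:

from modules.gitpython_hack import Repo

repo = Repo("/path/to/some/extension")  # hypothetical path
commit = repo.head.commit               # resolved via the one-shot cat-file overrides above
print(commit.hexsha[:8])                # the same short version string extensions.py derives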
...@@ -2,8 +2,6 @@ import datetime ...@@ -2,8 +2,6 @@ import datetime
import glob import glob
import html import html
import os import os
import sys
import traceback
import inspect import inspect
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
...@@ -11,7 +9,7 @@ import torch ...@@ -11,7 +9,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
...@@ -325,17 +323,14 @@ def load_hypernetwork(name): ...@@ -325,17 +323,14 @@ def load_hypernetwork(name):
if path is None: if path is None:
return None return None
hypernetwork = Hypernetwork()
try: try:
hypernetwork = Hypernetwork()
hypernetwork.load(path) hypernetwork.load(path)
return hypernetwork
except Exception: except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr) errors.report(f"Error loading hypernetwork {path}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
return None return None
return hypernetwork
def load_hypernetworks(names, multipliers=None): def load_hypernetworks(names, multipliers=None):
already_loaded = {} already_loaded = {}
...@@ -358,17 +353,6 @@ def load_hypernetworks(names, multipliers=None): ...@@ -358,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork) shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
if not search:
return None
search = search.lower()
applicable = [name for name in shared.hypernetworks if search in name.lower()]
if not applicable:
return None
applicable = sorted(applicable, key=lambda name: len(name))
return applicable[0]
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
...@@ -451,18 +435,6 @@ def statistics(data): ...@@ -451,18 +435,6 @@ def statistics(data):
return total_information, recent_information return total_information, recent_information
def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
try:
print("Loss statistics for file " + key)
info, recent = statistics(list(loss_info[key]))
print(info)
print(recent)
except Exception as e:
print(e)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name. # Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- ")) name = "".join( x for x in name if (x.isalnum() or x in "._- "))
...@@ -770,12 +742,11 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -770,12 +742,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
except Exception: except Exception:
print(traceback.format_exc(), file=sys.stderr) errors.report("Exception in training hypernetwork", exc_info=True)
finally: finally:
pbar.leave = False pbar.leave = False
pbar.close() pbar.close()
hypernetwork.eval() hypernetwork.eval()
#report_statistics(loss_dict)
sd_hijack_checkpoint.remove() sd_hijack_checkpoint.remove()
......
import os import os
from pathlib import Path
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
import gradio as gr
from modules import sd_samplers, images as imgutil from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
...@@ -13,7 +15,7 @@ from modules.ui import plaintext_to_html ...@@ -13,7 +15,7 @@ from modules.ui import plaintext_to_html
import modules.scripts import modules.scripts
def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, output_dir, inpaint_mask_dir, args): def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p) processing.fix_seed(p)
images = shared.listfiles(input_dir) images = shared.listfiles(input_dir)
...@@ -21,9 +23,10 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp ...@@ -21,9 +23,10 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp
is_inpaint_batch = False is_inpaint_batch = False
if inpaint_mask_dir: if inpaint_mask_dir:
inpaint_masks = shared.listfiles(inpaint_mask_dir) inpaint_masks = shared.listfiles(inpaint_mask_dir)
is_inpaint_batch = len(inpaint_masks) > 0 is_inpaint_batch = bool(inpaint_masks)
if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
...@@ -57,14 +60,31 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp ...@@ -57,14 +60,31 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp
continue continue
# Use the EXIF orientation of photos taken by smartphones. # Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img) img = ImageOps.exif_transpose(img)
if to_scale:
p.width = int(img.width * scale_by)
p.height = int(img.height * scale_by)
p.init_images = [img] * p.batch_size p.init_images = [img] * p.batch_size
image_path = Path(image)
if is_inpaint_batch: if is_inpaint_batch:
# try to find corresponding mask for an image using simple filename matching # try to find corresponding mask for an image using simple filename matching
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image)) if len(inpaint_masks) == 1:
# if not found use first one ("same mask for all images" use-case)
if mask_image_path not in inpaint_masks:
mask_image_path = inpaint_masks[0] mask_image_path = inpaint_masks[0]
else:
# try to find corresponding mask for an image using simple filename matching
mask_image_dir = Path(inpaint_mask_dir)
masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
if len(masks_found) == 0:
print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
continue
# there should be exactly one matching mask;
# multiple matches mean the user has masks with the same name but different extensions
mask_image_path = masks_found[0]
mask_image = Image.open(mask_image_path) mask_image = Image.open(mask_image_path)
p.image_mask = mask_image p.image_mask = mask_image
...@@ -103,7 +123,7 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp ...@@ -103,7 +123,7 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp
proc = process_images(p) proc = process_images(p)
for n, processed_image in enumerate(proc.images): for n, processed_image in enumerate(proc.images):
filename = os.path.basename(image) filename = image_path.name
if n > 0: if n > 0:
left, right = os.path.splitext(filename) left, right = os.path.splitext(filename)
...@@ -116,7 +136,7 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp ...@@ -116,7 +136,7 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
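The stem-based mask lookup introduced above can be isolated into a short sketch; find_mask and the directory names are hypothetical, not part of the codebase:

from pathlib import Path

def find_mask(image_path: Path, mask_dir: Path):
    # match e.g. "photo.png" to "photo.png" / "photo.jpg" in mask_dir by filename stem
    matches = list(mask_dir.glob(f"{image_path.stem}.*"))
    return matches[0] if matches else None  # None corresponds to the "skip image" branch

print(find_mask(Path("inputs/photo.png"), Path("masks")))  # hypothetical paths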
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
...@@ -130,7 +150,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s ...@@ -130,7 +150,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
elif mode == 2: # inpaint elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB") image = image.convert("RGB")
elif mode == 3: # inpaint sketch elif mode == 3: # inpaint sketch
image = inpaint_color_sketch image = inpaint_color_sketch
...@@ -152,7 +173,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s ...@@ -152,7 +173,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None: if image is not None:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
if selected_scale_tab == 1: if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected" assert image, "Can't scale by because no image is selected"
width = int(image.width * scale_by) width = int(image.width * scale_by)
...@@ -198,6 +219,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s ...@@ -198,6 +219,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img p.scripts = modules.scripts.scripts_img2img
p.script_args = args p.script_args = args
p.user = request.username
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
...@@ -207,7 +230,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s ...@@ -207,7 +230,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch: if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
process_batch(p, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
processed = Processed(p, [], p.seed, "") processed = Processed(p, [], p.seed, "")
else: else:
......
import os import os
import sys import sys
import traceback
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
import re import re
...@@ -185,8 +184,7 @@ class InterrogateModels: ...@@ -185,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image): def interrogate(self, pil_image):
res = "" res = ""
shared.state.begin() shared.state.begin(job="interrogate")
shared.state.job = 'interrogate'
try: try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
...@@ -216,8 +214,7 @@ class InterrogateModels: ...@@ -216,8 +214,7 @@ class InterrogateModels:
res += f", {match}" res += f", {match}"
except Exception: except Exception:
print("Error interrogating", file=sys.stderr) errors.report("Error interrogating", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
res += "<error>" res += "<error>"
self.unload() self.unload()
......
...@@ -7,7 +7,7 @@ import platform ...@@ -7,7 +7,7 @@ import platform
import json import json
from functools import lru_cache from functools import lru_cache
from modules import cmd_args from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir from modules.paths_internal import script_path, extensions_dir
args, _ = cmd_args.parser.parse_known_args() args, _ = cmd_args.parser.parse_known_args()
...@@ -68,7 +68,13 @@ def git_tag(): ...@@ -68,7 +68,13 @@ def git_tag():
try: try:
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception: except Exception:
return "<none>" try:
from pathlib import Path
changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
with changelog_md.open(encoding="utf-8") as file:
return next((line.strip() for line in file if line.strip()), "<none>")
except Exception:
return "<none>"
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str: def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
...@@ -136,15 +142,15 @@ def git_clone(url, dir, name, commithash=None): ...@@ -136,15 +142,15 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None: if commithash is None:
return return
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash: if current_hash == commithash:
return return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None: if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
...@@ -188,7 +194,7 @@ def run_extension_installer(extension_dir): ...@@ -188,7 +194,7 @@ def run_extension_installer(extension_dir):
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e: except Exception as e:
print(e, file=sys.stderr) errors.report(str(e))
def list_extensions(settings_file): def list_extensions(settings_file):
...@@ -198,8 +204,8 @@ def list_extensions(settings_file): ...@@ -198,8 +204,8 @@ def list_extensions(settings_file):
if os.path.isfile(settings_file): if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file: with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file) settings = json.load(file)
except Exception as e: except Exception:
print(e, file=sys.stderr) errors.report("Could not load settings", exc_info=True)
disabled_extensions = set(settings.get('disabled_extensions', [])) disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none') disable_all_extensions = settings.get('disable_all_extensions', 'none')
...@@ -223,23 +229,28 @@ def prepare_environment(): ...@@ -223,23 +229,28 @@ def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart"))
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError:
pass
if not args.skip_python_version_check: if not args.skip_python_version_check:
check_python_version() check_python_version()
...@@ -286,7 +297,6 @@ def prepare_environment(): ...@@ -286,7 +297,6 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
......
import json import json
import os import os
import sys
import traceback
from modules import errors
localizations = {} localizations = {}
...@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: ...@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file: with open(fn, "r", encoding="utf8") as file:
data = json.load(file) data = json.load(file)
except Exception: except Exception:
print(f"Error loading localization from {fn}:", file=sys.stderr) errors.report(f"Error loading localization from {fn}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
return f"window.localization = {json.dumps(data)}" return f"window.localization = {json.dumps(data)}"
...@@ -15,6 +15,8 @@ def send_everything_to_cpu(): ...@@ -15,6 +15,8 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram): def setup_for_low_vram(sd_model, use_medvram):
sd_model.lowvram = True
parents = {} parents = {}
def send_me_to_gpu(module, _): def send_me_to_gpu(module, _):
...@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks: for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu) block.register_forward_pre_hook(send_me_to_gpu)
def is_enabled(sd_model):
return getattr(sd_model, 'lowvram', False)
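The new lowvram flag and is_enabled helper give callers a uniform way to test whether low-VRAM offloading was set up for a model. A minimal usage sketch, assuming shared.sd_model holds the currently loaded model:

from modules import lowvram, shared

if lowvram.is_enabled(shared.sd_model):
    # layers are shuttled to the GPU on demand; release them before heavy CPU work
    lowvram.send_everything_to_cpu()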
...@@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc ...@@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc
from packaging import version from packaging import version
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# check `getattr` and try it for compatibility # so use `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() were introduced to check MPS availability,
# since the torch 2.0.1+ nightly builds, getattr(torch, 'has_mps', False) has been deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool: def check_for_mps() -> bool:
if not getattr(torch, 'has_mps', False): if version.parse(torch.__version__) <= version.parse("2.0.1"):
return False if not getattr(torch, 'has_mps', False):
try: return False
torch.zeros(1).to(torch.device("mps")) try:
return True torch.zeros(1).to(torch.device("mps"))
except Exception: return True
return False except Exception:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps() has_mps = check_for_mps()
......
from __future__ import annotations
import os import os
import shutil import shutil
import importlib import importlib
...@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale ...@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.
Returns the path to the downloaded file.
"""
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
from torch.hub import download_url_to_file
download_url_to_file(url, cached_file, progress=progress)
return cached_file
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
""" """
A one-and done loader to try finding the desired models in specified directories. A one-and done loader to try finding the desired models in specified directories.
...@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None ...@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0: if model_url is not None and len(output) == 0:
if download_name is not None: if download_name is not None:
from basicsr.utils.download_util import load_file_from_url output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
dl = load_file_from_url(model_url, places[0], True, download_name)
output.append(dl)
else: else:
output.append(model_url) output.append(model_url)
...@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None ...@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str): def friendly_name(file: str):
if "http" in file: if file.startswith("http"):
file = urlparse(file).path file = urlparse(file).path
file = os.path.basename(file) file = os.path.basename(file)
...@@ -95,8 +118,7 @@ def cleanup_models(): ...@@ -95,8 +118,7 @@ def cleanup_models():
def move_files(src_path: str, dest_path: str, ext_filter: str = None): def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try: try:
if not os.path.exists(dest_path): os.makedirs(dest_path, exist_ok=True)
os.makedirs(dest_path)
if os.path.exists(src_path): if os.path.exists(src_path):
for file in os.listdir(src_path): for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file) fullpath = os.path.join(src_path, file)
......
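modelloader.load_file_from_url, defined above, replaces the basicsr download helper used before this commit. A usage sketch with illustrative values (the URL and directory below are placeholders, not real assets):

from modules import modelloader

path = modelloader.load_file_from_url(
    "https://example.com/upscaler.pth",  # placeholder URL
    model_dir="models/ESRGAN",           # hypothetical destination directory
    file_name="upscaler.pth",
)
# only the first call downloads; later calls return the cached path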
...@@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): ...@@ -230,9 +230,9 @@ class DDPM(pl.LightningModule):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False) sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0: if missing:
print(f"Missing Keys: {missing}") print(f"Missing Keys: {missing}")
if len(unexpected) > 0: if unexpected:
print(f"Unexpected Keys: {unexpected}") print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t): def q_mean_variance(self, x_start, t):
......
...@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl ...@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []), (sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
...@@ -39,17 +38,3 @@ for d, must_exist, what, options in path_dirs: ...@@ -39,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else: else:
sys.path.append(d) sys.path.append(d)
paths[what] = d paths[what] = d
class Prioritize:
def __init__(self, name):
self.name = name
self.path = None
def __enter__(self):
self.path = sys.path.copy()
sys.path = [paths[self.name]] + sys.path
def __exit__(self, exc_type, exc_val, exc_tb):
sys.path = self.path
self.path = None
...@@ -9,8 +9,7 @@ from modules.shared import opts ...@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc() devices.torch_gc()
shared.state.begin() shared.state.begin(job="extras")
shared.state.job = 'extras'
image_data = [] image_data = []
image_names = [] image_names = []
...@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, ...@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
for image, name in zip(image_data, image_names): for image, name in zip(image_data, image_names):
shared.state.textinfo = name shared.state.textinfo = name
existing_pnginfo = image.info or {} parameters, existing_pnginfo = images.read_info_from_image(image)
if parameters:
existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
......
...@@ -336,11 +336,11 @@ def parse_prompt_attention(text): ...@@ -336,11 +336,11 @@ def parse_prompt_attention(text):
round_brackets.append(len(res)) round_brackets.append(len(res))
elif text == '[': elif text == '[':
square_brackets.append(len(res)) square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0: elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight)) multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and len(round_brackets) > 0: elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier) multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and len(square_brackets) > 0: elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier) multiply_range(square_brackets.pop(), square_bracket_multiplier)
else: else:
parts = re.split(re_break, text) parts = re.split(re_break, text)
......
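For reference, the bracket handling above scales attention on prompt fragments. Illustrative expected outputs, assuming the default multipliers of 1.1 for round brackets and 1/1.1 for square brackets:

parse_prompt_attention("(abc)")            # -> [["abc", 1.1]]
parse_prompt_attention("[abc]")            # -> [["abc", 1.0 / 1.1]]
parse_prompt_attention("a (red:1.2) cat")  # -> [["a ", 1.0], ["red", 1.2], [" cat", 1.0]]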
import os import os
import sys
import traceback
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
from modules import modelloader from modules import modelloader, errors
class UpscalerRealESRGAN(Upscaler): class UpscalerRealESRGAN(Upscaler):
def __init__(self, path): def __init__(self, path):
...@@ -36,8 +34,7 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -36,8 +34,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler) self.scalers.append(scaler)
except Exception: except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr) errors.report("Error importing Real-ESRGAN", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
self.enable = False self.enable = False
self.scalers = [] self.scalers = []
...@@ -45,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -45,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable: if not self.enable:
return img return img
info = self.load_model(path) try:
if not os.path.exists(info.local_data_path): info = self.load_model(path)
print(f"Unable to load RealESRGAN model: {info.name}") except Exception:
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img return img
upsampler = RealESRGANer( upsampler = RealESRGANer(
...@@ -65,21 +63,17 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -65,21 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image return image
def load_model(self, path): def load_model(self, path):
try: for scaler in self.scalers:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
if info is None: scaler.local_data_path = modelloader.load_file_from_url(
print(f"Unable to find model info: {path}") scaler.data_path,
return None model_dir=self.model_download_path,
)
if info.local_data_path.startswith("http"): if not os.path.exists(scaler.local_data_path):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
return info raise ValueError(f"Unable to find model info: {path}")
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
def load_models(self, _): def load_models(self, _):
return get_realesrgan_models(self) return get_realesrgan_models(self)
...@@ -135,5 +129,4 @@ def get_realesrgan_models(scaler): ...@@ -135,5 +129,4 @@ def get_realesrgan_models(scaler):
] ]
return models return models
except Exception: except Exception:
print("Error making Real-ESRGAN models list:", file=sys.stderr) errors.report("Error making Real-ESRGAN models list", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
import os
from pathlib import Path
from modules.paths_internal import script_path
def is_restartable() -> bool:
"""
Return True if the webui can be restarted (i.e. a wrapper process is watching and will relaunch it)
"""
return bool(os.environ.get('SD_WEBUI_RESTART'))
def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
(Path(script_path) / "tmp" / "restart").touch()
stop_program()
def stop_program() -> None:
os._exit(0)
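A sketch of how the new restart module is meant to be driven, for example from a UI action (the call site here is illustrative):

from modules import restart

if restart.is_restartable():
    restart.restart_program()  # touch tmp/restart and exit; webui.sh/webui.bat relaunch the app
else:
    restart.stop_program()     # nothing is watching, so just stop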
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
import pickle import pickle
import collections import collections
import sys
import traceback
import torch import torch
import numpy import numpy
...@@ -11,7 +9,10 @@ import _codecs ...@@ -11,7 +9,10 @@ import _codecs
import zipfile import zipfile
import re import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args): def encode(*args):
...@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): ...@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler) check_pt(filename, extra_handler)
except pickle.UnpicklingError: except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) errors.report(
print(traceback.format_exc(), file=sys.stderr) f"Error verifying pickled file from {filename}\n"
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) "-----> !!!! The file is most likely corrupted !!!! <-----\n"
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
exc_info=True,
)
return None return None
except Exception: except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) errors.report(
print(traceback.format_exc(), file=sys.stderr) f"Error verifying pickled file from {filename}\n"
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) f"The file may be malicious, so the program is not going to read it.\n"
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
exc_info=True,
)
return None return None
return unsafe_torch_load(filename, *args, **kwargs) return unsafe_torch_load(filename, *args, **kwargs)
...@@ -190,4 +194,3 @@ with safe.Extra(handler): ...@@ -190,4 +194,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load unsafe_torch_load = torch.load
torch.load = load torch.load = load
global_extra_handler = None global_extra_handler = None
import sys
import traceback
from collections import namedtuple
import inspect import inspect
import os
from collections import namedtuple
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
from modules import errors, timer
def report_exception(c, job): def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr) errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
class ImageSaveParams: class ImageSaveParams:
...@@ -111,6 +111,7 @@ callback_map = dict( ...@@ -111,6 +111,7 @@ callback_map = dict(
callbacks_before_ui=[], callbacks_before_ui=[],
callbacks_on_reload=[], callbacks_on_reload=[],
callbacks_list_optimizers=[], callbacks_list_optimizers=[],
callbacks_list_unets=[],
) )
...@@ -123,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): ...@@ -123,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']: for c in callback_map['callbacks_app_started']:
try: try:
c.callback(demo, app) c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script))
except Exception: except Exception:
report_exception(c, 'app_started_callback') report_exception(c, 'app_started_callback')
...@@ -271,16 +273,28 @@ def list_optimizers_callback(): ...@@ -271,16 +273,28 @@ def list_optimizers_callback():
return res return res
def list_unets_callback():
res = []
for c in callback_map['callbacks_list_unets']:
try:
c.callback(res)
except Exception:
report_exception(c, 'list_unets')
return res
def add_callback(callbacks, fun): def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks(): def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if stack else 'unknown file'
if filename == 'unknown file': if filename == 'unknown file':
return return
for callback_list in callback_map.values(): for callback_list in callback_map.values():
...@@ -430,3 +444,10 @@ def on_list_optimizers(callback): ...@@ -430,3 +444,10 @@ def on_list_optimizers(callback):
to it.""" to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback) add_callback(callback_map['callbacks_list_optimizers'], callback)
def on_list_unets(callback):
"""register a function to be called when UI is making a list of alternative options for unet.
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
add_callback(callback_map['callbacks_list_unets'], callback)
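A hypothetical extension-side registration for the new callback, assuming modules.sd_unet.SdUnetOption exposes a label and a create_unet() factory as the docstring above implies:

from modules import script_callbacks, sd_unet

class MyUnetOption(sd_unet.SdUnetOption):
    label = "My replacement UNet"  # name presented to the user

    def create_unet(self):
        # assumed factory hook returning the replacement unet implementation
        raise NotImplementedError

def list_unets(options):
    options.append(MyUnetOption())

script_callbacks.on_list_unets(list_unets)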
import os import os
import sys
import traceback
import importlib.util import importlib.util
from modules import errors
def load_module(path): def load_module(path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
...@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): ...@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser) module.preload(parser)
except Exception: except Exception:
print(f"Error running preload() for {preload_script}", file=sys.stderr) errors.report(f"Error running preload() for {preload_script}", exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
...@@ -3,7 +3,7 @@ from torch.nn.functional import silu ...@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType from types import MethodType
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
...@@ -43,7 +43,7 @@ def list_optimizers(): ...@@ -43,7 +43,7 @@ def list_optimizers():
optimizers.extend(new_optimizers) optimizers.extend(new_optimizers)
def apply_optimizations(): def apply_optimizations(option=None):
global current_optimizer global current_optimizer
undo_optimizations() undo_optimizations()
...@@ -60,7 +60,7 @@ def apply_optimizations(): ...@@ -60,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo() current_optimizer.undo()
current_optimizer = None current_optimizer = None
selection = shared.opts.cross_attention_optimization selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0: if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else: else:
...@@ -74,12 +74,13 @@ def apply_optimizations(): ...@@ -74,12 +74,13 @@ def apply_optimizations():
matching_optimizer = optimizers[0] matching_optimizer = optimizers[0]
if matching_optimizer is not None: if matching_optimizer is not None:
print(f"Applying optimization: {matching_optimizer.name}... ", end='') print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply() matching_optimizer.apply()
print("done.") print("done.")
current_optimizer = matching_optimizer current_optimizer = matching_optimizer
return current_optimizer.name return current_optimizer.name
else: else:
print("Disabling attention optimization")
return '' return ''
...@@ -157,9 +158,9 @@ class StableDiffusionModelHijack: ...@@ -157,9 +158,9 @@ class StableDiffusionModelHijack:
def __init__(self): def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self): def apply_optimizations(self, option=None):
try: try:
self.optimization_method = apply_optimizations() self.optimization_method = apply_optimizations(option)
except Exception as e: except Exception as e:
errors.display(e, "applying cross attention optimization") errors.display(e, "applying cross attention optimization")
undo_optimizations() undo_optimizations()
...@@ -196,6 +197,11 @@ class StableDiffusionModelHijack: ...@@ -196,6 +197,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
...@@ -217,6 +223,8 @@ class StableDiffusionModelHijack: ...@@ -217,6 +223,8 @@ class StableDiffusionModelHijack:
self.layers = None self.layers = None
self.clip = None self.clip = None
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
return return
......
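The new option parameter lets callers override the Settings value when reapplying cross-attention optimizations. Hypothetical call sites, assuming the module-level model_hijack instance:

model_hijack.apply_optimizations()            # resolve from shared.opts, as before
model_hijack.apply_optimizations("xformers")  # force a specific optimizer by name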
...@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): ...@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk.multipliers += [weight] * emb_len chunk.multipliers += [weight] * emb_len
position += embedding_length_in_tokens position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0: if chunk.tokens or not chunks:
next_chunk(is_last=True) next_chunk(is_last=True)
return chunks, token_count return chunks, token_count
......
...@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text ...@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if used_custom_terms:
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}") self.hijack.comments.append(f"Used embeddings: {embedding_names}")
......
...@@ -32,7 +32,7 @@ class DatasetEntry: ...@@ -32,7 +32,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
......
...@@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic ...@@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
...@@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step ...@@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
p.script_args = args p.script_args = args
p.user = request.username
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
......
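The txt2img handler now receives the Gradio request object, which carries the authenticated username into the processing object. A small sketch of the underlying Gradio convention (Gradio injects the request for parameters annotated with gr.Request; the handler shown is illustrative):

import gradio as gr

def handler(prompt: str, request: gr.Request):
    # gradio fills `request` automatically based on the type annotation
    return f"user={request.username}: {prompt}"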