Commit 97e1cf69 authored by AUTOMATIC1111, committed by GitHub

Merge branch 'dev' into master

parents 484948f5 bb431df5
.eslintignore (new file):

extensions
extensions-disabled
repositories
venv
\ No newline at end of file
.eslintrc.js (new file):

module.exports = {
    env: {
        browser: true,
        es2021: true,
    },
    extends: "eslint:recommended",
    parserOptions: {
        ecmaVersion: "latest",
    },
    rules: {
        "arrow-spacing": "error",
        "block-spacing": "error",
        "brace-style": "error",
        "comma-dangle": ["error", "only-multiline"],
        "comma-spacing": "error",
        "comma-style": ["error", "last"],
        "curly": ["error", "multi-line", "consistent"],
        "eol-last": "error",
        "func-call-spacing": "error",
        "function-call-argument-newline": ["error", "consistent"],
        "function-paren-newline": ["error", "consistent"],
        "indent": ["error", 4],
        "key-spacing": "error",
        "keyword-spacing": "error",
        "linebreak-style": ["error", "unix"],
        "no-extra-semi": "error",
        "no-mixed-spaces-and-tabs": "error",
        "no-trailing-spaces": "error",
        "no-whitespace-before-property": "error",
        "object-curly-newline": ["error", {consistent: true, multiline: true}],
        "quote-props": ["error", "consistent-as-needed"],
        "semi": ["error", "always"],
        "semi-spacing": "error",
        "semi-style": ["error", "last"],
        "space-before-blocks": "error",
        "space-before-function-paren": ["error", "never"],
        "space-in-parens": ["error", "never"],
        "space-infix-ops": "error",
        "space-unary-ops": "error",
        "switch-colon-spacing": "error",
        "template-curly-spacing": ["error", "never"],
        "unicode-bom": "error",
        "no-multi-spaces": "error",
        "object-curly-spacing": ["error", "never"],
        "operator-linebreak": ["error", "after"],
        "no-unused-vars": "off",
        "no-redeclare": "off",
    },
    globals: {
        // this file
        module: "writable",
        // script.js
        gradioApp: "writable",
        onUiLoaded: "writable",
        onUiUpdate: "writable",
        onOptionsChanged: "writable",
        uiCurrentTab: "writable",
        uiElementIsVisible: "writable",
        executeCallbacks: "writable",
        // ui.js
        opts: "writable",
        all_gallery_buttons: "writable",
        selected_gallery_button: "writable",
        selected_gallery_index: "writable",
        args_to_array: "writable",
        switch_to_txt2img: "writable",
        switch_to_img2img_tab: "writable",
        switch_to_img2img: "writable",
        switch_to_sketch: "writable",
        switch_to_inpaint: "writable",
        switch_to_inpaint_sketch: "writable",
        switch_to_extras: "writable",
        get_tab_index: "writable",
        create_submit_args: "writable",
        restart_reload: "writable",
        updateInput: "writable",
        // extraNetworks.js
        requestGet: "writable",
        popup: "writable",
        // from python
        localization: "writable",
        // progressbar.js
        randomId: "writable",
        requestProgress: "writable",
        // imageviewer.js
        modalPrevImage: "writable",
        modalNextImage: "writable",
    }
};
.github/ISSUE_TEMPLATE/bug_report.yml:

@@ -47,6 +47,15 @@ body:
       description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
     validations:
       required: true
+  - type: dropdown
+    id: py-version
+    attributes:
+      label: What Python version are you running on ?
+      multiple: false
+      options:
+        - Python 3.10.x
+        - Python 3.11.x (above, not supported yet)
+        - Python 3.9.x (below, not recommended)
   - type: dropdown
     id: platforms
     attributes:
@@ -59,6 +68,18 @@ body:
         - iOS
         - Android
         - Other/Cloud
+  - type: dropdown
+    id: device
+    attributes:
+      label: What device are you running WebUI on?
+      multiple: true
+      options:
+        - Nvidia GPUs (RTX 20 above)
+        - Nvidia GPUs (GTX 16 below)
+        - AMD GPUs (RX 6000 above)
+        - AMD GPUs (RX 5000 below)
+        - CPU
+        - Other GPUs
   - type: dropdown
     id: browsers
     attributes:
...
.github/workflows/on_pull_request.yaml:

-# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
 name: Run Linting/Formatting on Pull Requests

 on:
   - push
   - pull_request
-  # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
-  # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
-  # pull_request:
-  #   branches:
-  #     - master
-  #   branches-ignore:
-  #     - development

 jobs:
-  lint:
+  lint-python:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Code
         uses: actions/checkout@v3
-      - name: Set up Python 3.10
-        uses: actions/setup-python@v4
+      - uses: actions/setup-python@v4
         with:
-          python-version: 3.10.6
-          cache: pip
-          cache-dependency-path: |
-            **/requirements*txt
-      - name: Install PyLint
-        run: |
-          python -m pip install --upgrade pip
-          pip install pylint
-      # This lets PyLint check to see if it can resolve imports
-      - name: Install dependencies
-        run: |
-          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
-          python launch.py
-      - name: Analysing the code with pylint
-        run: |
-          pylint $(git ls-files '*.py')
+          python-version: 3.11
+          # NB: there's no cache: pip here since we're not installing anything
+          #     from the requirements.txt file(s) in the repository; it's faster
+          #     not to have GHA download an (at the time of writing) 4 GB cache
+          #     of PyTorch and other dependencies.
+      - name: Install Ruff
+        run: pip install ruff==0.0.265
+      - name: Run Ruff
+        run: ruff .
+  lint-js:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v3
+      - name: Install Node.js
+        uses: actions/setup-node@v3
+        with:
+          node-version: 18
+      - run: npm i --ci
+      - run: npm run lint
.github/workflows/run_tests.yaml:

@@ -17,8 +17,14 @@ jobs:
           cache: pip
           cache-dependency-path: |
             **/requirements*txt
+            launch.py
       - name: Run tests
         run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+        env:
+          PIP_DISABLE_PIP_VERSION_CHECK: "1"
+          PIP_PROGRESS_BAR: "off"
+          TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
+          WEBUI_LAUNCH_LIVE_OUTPUT: "1"
       - name: Upload main app stdout-stderr
         uses: actions/upload-artifact@v3
         if: always()
...
.gitignore:

@@ -34,3 +34,5 @@ notification.mp3
 /test/stderr.txt
 /cache.json*
 /config_states/
+/node_modules
+/package-lock.json
\ No newline at end of file
README.md:

@@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab):
 - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)

+### Installation on Windows 10/11 with NVidia GPUs using the release package
+1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
+2. Run `update.bat`.
+3. Run `run.bat`.
+> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
+
 ### Automatic Installation on Windows
 1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (newer versions of Python do not support torch), checking "Add Python to PATH".
 2. Install [git](https://git-scm.com/download/win).
@@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in html/licenses.html file.
 - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
+- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
extensions-builtin/LDSR/ldsr_model_arch.py:

@@ -88,7 +88,7 @@ class LDSR:
         x_t = None
         logs = None
-        for n in range(n_runs):
+        for _ in range(n_runs):
             if custom_shape is not None:
                 x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                 x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
         diffusion_steps = int(steps)
         eta = 1.0
-        down_sample_method = 'Lanczos'

         gc.collect()
         if torch.cuda.is_available:
@@ -158,7 +157,7 @@ class LDSR:
 def get_cond(selected_path):
-    example = dict()
+    example = {}
     up_f = 4
     c = selected_path.convert('RGB')
     c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
 @torch.no_grad()
 def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                               corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
-    log = dict()
+    log = {}
     z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                         return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
             x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
             log["sample_noquant"] = x_sample_noquant
             log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
-        except:
+        except Exception:
             pass

         log["sample"] = x_sample
...

extensions-builtin/LDSR/scripts/ldsr_model.py:

@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+import sd_hijack_autoencoder  # noqa: F401
+import sd_hijack_ddpm_v1  # noqa: F401

 class UpscalerLDSR(Upscaler):
...
extensions-builtin/LDSR/sd_hijack_autoencoder.py:

 # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
 # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
 # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
+import numpy as np
 import torch
 import pytorch_lightning as pl
 import torch.nn.functional as F
 from contextlib import contextmanager
+from torch.optim.lr_scheduler import LambdaLR
+from ldm.modules.ema import LitEma

 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.util import instantiate_from_config
 import ldm.models.autoencoder
+from packaging import version

 class VQModel(pl.LightningModule):
     def __init__(self,
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
                  n_embed,
                  embed_dim,
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  image_key="image",
                  colorize_nlabels=None,
                  monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
         self.scheduler_config = scheduler_config
         self.lr_g_factor = lr_g_factor
@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
         if context is not None:
             print(f"{context}: Restored training weights")

-    def init_from_ckpt(self, path, ignore_keys=list()):
+    def init_from_ckpt(self, path, ignore_keys=None):
         sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
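The `ignore_keys=[]` → `ignore_keys=None` rewrites above fix Python's mutable-default-argument pitfall: a default list is created once at function definition and then shared by every call. A small sketch of the failure mode and of the `None`-plus-`or []` pattern adopted here (illustrative names only):

def remember_bad(key, seen=[]):      # one shared list for *all* calls
    seen.append(key)
    return seen

print(remember_bad("a"))             # ['a']
print(remember_bad("b"))             # ['a', 'b'] - state leaked across calls

def remember_good(key, seen=None):   # the pattern used in this merge
    seen = seen or []                # fresh list on each call when omitted
    seen.append(key)
    return seen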
@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         log_dict = self._validation_step(batch, batch_idx)
         with self.ema_scope():
-            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+            self._validation_step(batch, batch_idx, suffix="_ema")
         return log_dict

     def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
         return self.decoder.conv_out.weight

     def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.image_key)
         x = x.to(self.device)
         if only_inputs:
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
         if plot_ema:
             with self.ema_scope():
                 xrec_ema, _ = self(x)
-                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+                if x.shape[1] > 3:
+                    xrec_ema = self.to_rgb(xrec_ema)
                 log["reconstructions_ema"] = xrec_ema
         return log
@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
 class VQModelInterface(VQModel):
     def __init__(self, embed_dim, *args, **kwargs):
-        super().__init__(embed_dim=embed_dim, *args, **kwargs)
+        super().__init__(*args, embed_dim=embed_dim, **kwargs)
         self.embed_dim = embed_dim

     def encode(self, x):
@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
         dec = self.decoder(quant)
         return dec

-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py:

@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
                  beta_schedule="linear",
                  loss_type="l2",
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  load_only_unet=False,
                  monitor="val/loss",
                  use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
         if monitor is not None:
             self.monitor = monitor
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)

         self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                                linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
         if context is not None:
             print(f"{context}: Restored training weights")

-    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
             sd = sd["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.first_stage_key)
         N = min(x.shape[0], N)
         n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
         log["inputs"] = x

         # get diffusion row
-        diffusion_row = list()
+        diffusion_row = []
         x_start = x[:n_row]

         for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
             conditioning_key = None
         ckpt_path = kwargs.pop("ckpt_path", None)
         ignore_keys = kwargs.pop("ignore_keys", [])
-        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)

-    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
-        def rescale_bbox(bbox):
-            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
-            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
-            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
-            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
-            return x0, y0, w, h
-
-        return [rescale_bbox(b) for b in bboxes]
-
     def apply_model(self, x_noisy, t, cond, return_ids=False):

         if isinstance(cond, dict):
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                        list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                        [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(x0_partial)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         return img, intermediates

     @torch.no_grad()
@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(img)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)

         if return_intermediates:
             return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                        list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                        [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
         return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
         use_ddim = ddim_steps is not None

-        log = dict()
+        log = {}
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
                                            force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
         if plot_diffusion_rows:
             # get diffusion row
-            diffusion_row = list()
+            diffusion_row = []
             z_start = z[:n_row]
             for t in range(self.num_timesteps):
                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
             if inpaint:
                 # make a simple center square
-                b, h, w = z.shape[0], z.shape[2], z.shape[3]
+                h, w = z.shape[2], z.shape[3]
                 mask = torch.ones(N, h, w).to(self.device)
                 # zeros will be filled in
                 mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
         assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
-        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)

     def log_images(self, batch, N=8, *args, **kwargs):
-        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+        logs = super().log_images(*args, batch=batch, N=N, **kwargs)

         key = 'train' if self.training else 'validation'
         dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
         logs['bbox_image'] = cond_img
         return logs

-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
extensions-builtin/Lora/lora.py:

-import glob
 import os
 import re

 import torch
@@ -177,7 +176,7 @@ def load_lora(name, filename):
             else:
                 print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
                 continue
-                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+                raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")

             with torch.no_grad():
                 module.weight.copy_(weight)
@@ -189,7 +188,7 @@ def load_lora(name, filename):
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")

     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -207,7 +206,7 @@ def load_loras(names, multipliers=None):
         loaded_loras.clear()

     loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
-    if any([x is None for x in loras_on_disk]):
+    if any(x is None for x in loras_on_disk):
         list_available_loras()

         loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
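Dropping the brackets inside `any(...)` above turns an eagerly built list into a lazy generator, letting the scan short-circuit at the first `None` instead of materializing a result for every entry. Sketch:

loras_on_disk = [object(), None] + [object()] * 10_000

any([x is None for x in loras_on_disk])  # builds a 10,002-element list first
any(x is None for x in loras_on_disk)    # stops as soon as the None is seen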
@@ -314,7 +313,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
             print(f'failed to calculate lora weights for layer {lora_layer_name}')

-    setattr(self, "lora_current_names", wanted_names)
+    self.lora_current_names = wanted_names

 def lora_forward(module, input, original_forward):
@@ -348,8 +347,8 @@ def lora_forward(module, input, original_forward):

 def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
-    setattr(self, "lora_current_names", ())
-    setattr(self, "lora_weights_backup", None)
+    self.lora_current_names = ()
+    self.lora_weights_backup = None

 def lora_Linear_forward(self, input):
@@ -428,7 +427,7 @@ def infotext_pasted(infotext, params):
     added = []

-    for k, v in params.items():
+    for k in params:
         if not k.startswith("AddNet Model "):
             continue
...
extensions-builtin/Lora/scripts/lora_script.py:

@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)

 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
     "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
 }))
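The dropdown-choices change above swaps a copy-and-concatenate for a single list literal with iterable unpacking; both yield the mapping's keys after the `"None"` sentinel. A tiny sketch with a stand-in mapping:

available = {"style_a": 1, "style_b": 2}   # stand-in for lora.available_loras

old = ["None"] + [x for x in available]    # comprehension copy, then concat
new = ["None", *available]                 # one literal, unpacked in place

assert old == new == ["None", "style_a", "style_b"]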
...

extensions-builtin/ScuNET/scripts/scunet_model.py:

@@ -10,10 +10,9 @@ from tqdm import tqdm
 from basicsr.utils.download_util import load_file_from_url

 import modules.upscaler
-from modules import devices, modelloader
+from modules import devices, modelloader, script_callbacks
 from scunet_model_arch import SCUNet as net
 from modules.shared import opts
-from modules import images

 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
         model.load_state_dict(torch.load(filename), strict=True)
         model.eval()
-        for k, v in model.named_parameters():
+        for _, v in model.named_parameters():
             v.requires_grad = False
         model = model.to(device)

         return model

+def on_ui_settings():
+    import gradio as gr
+    from modules import shared
+
+    shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
+    shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
+
+script_callbacks.on_ui_settings(on_ui_settings)
extensions-builtin/ScuNET/scunet_model_arch.py:

@@ -61,7 +61,9 @@ class WMSA(nn.Module):
         Returns:
             output: tensor shape [b h w c]
         """
-        if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+        if self.type != 'W':
+            x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
         x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
         h_windows = x.size(1)
         w_windows = x.size(2)
@@ -85,8 +87,9 @@ class WMSA(nn.Module):
         output = self.linear(output)
         output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)

-        if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
-                                                 dims=(1, 2))
+        if self.type != 'W':
+            output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))

         return output

     def relative_embedding(self):
...
extensions-builtin/SwinIR/scripts/swinir_model.py:

-import contextlib
 import os

 import numpy as np
@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm

 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
+from modules.shared import opts, state
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
             img = upscale(img, model)
             try:
                 torch.cuda.empty_cache()
-            except:
+            except Exception:
                 pass
             return img
...
extensions-builtin/SwinIR/swinir_model_arch.py:

@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
     """

     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
         H, W = self.patches_resolution
         flops += H * W * 3 * self.embed_dim * 9
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += H * W * 3 * self.embed_dim * self.embed_dim
         flops += self.upsample.flops()
...
extensions-builtin/SwinIR/swinir_model_arch_v2.py:

@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
     """

     def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=[0, 0]):
+                 pretrained_window_size=(0, 0)):

         super().__init__()
         self.dim = dim
@@ -698,7 +698,7 @@ class Swin2SR(nn.Module):
     """

     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
         H, W = self.patches_resolution
         flops += H * W * 3 * self.embed_dim * 9
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += H * W * 3 * self.embed_dim * self.embed_dim
         flops += self.upsample.flops()
...
javascript/promptBracketChecker.js:

@@ -5,7 +5,7 @@
 function checkBrackets(textArea, counterElt) {
     var counts = {};
-    (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
+    (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
         counts[bracket] = (counts[bracket] || 0) + 1;
     });
     var errors = [];
@@ -27,14 +27,14 @@ function checkBrackets(textArea, counterElt) {
 function setupBracketChecking(id_prompt, id_counter) {
     var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
-    var counter = gradioApp().getElementById(id_counter)
+    var counter = gradioApp().getElementById(id_counter);

     if (textarea && counter) {
         textarea.addEventListener("input", () => checkBrackets(textarea, counter));
     }
 }

-onUiLoaded(function () {
+onUiLoaded(function() {
     setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
     setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
     setupBracketChecking('img2img_prompt', 'img2img_token_counter');
...
html/extra-networks-card.html:

@@ -6,7 +6,7 @@
         <ul>
             <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
         </ul>
-        <span style="display:none" class='search_term{serach_only}'>{search_term}</span>
+        <span style="display:none" class='search_term{search_only}'>{search_term}</span>
     </div>
     <span class='name'>{name}</span>
     <span class='description'>{description}</span>
...
html/licenses.html:

@@ -662,3 +662,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 </pre>
+
+<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
+<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
\ No newline at end of file
javascript/aspectRatioOverlay.js:

 let currentWidth = null;
 let currentHeight = null;
-let arFrameTimeout = setTimeout(function(){},0);
+let arFrameTimeout = setTimeout(function() {}, 0);

-function dimensionChange(e, is_width, is_height){
+function dimensionChange(e, is_width, is_height) {

-    if(is_width){
-        currentWidth = e.target.value*1.0
+    if (is_width) {
+        currentWidth = e.target.value * 1.0;
     }
-    if(is_height){
-        currentHeight = e.target.value*1.0
+    if (is_height) {
+        currentHeight = e.target.value * 1.0;
     }

     var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";

-    if(!inImg2img){
+    if (!inImg2img) {
         return;
     }

     var targetElement = null;

-    var tabIndex = get_tab_index('mode_img2img')
-    if(tabIndex == 0){ // img2img
+    var tabIndex = get_tab_index('mode_img2img');
+    if (tabIndex == 0) { // img2img
         targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
-    } else if(tabIndex == 1){ //Sketch
+    } else if (tabIndex == 1) { //Sketch
         targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
-    } else if(tabIndex == 2){ // Inpaint
+    } else if (tabIndex == 2) { // Inpaint
         targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
-    } else if(tabIndex == 3){ // Inpaint sketch
+    } else if (tabIndex == 3) { // Inpaint sketch
         targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
     }

-    if(targetElement){
+    if (targetElement) {

         var arPreviewRect = gradioApp().querySelector('#imageARPreview');
-        if(!arPreviewRect){
-            arPreviewRect = document.createElement('div')
+        if (!arPreviewRect) {
+            arPreviewRect = document.createElement('div');
             arPreviewRect.id = "imageARPreview";
-            gradioApp().appendChild(arPreviewRect)
+            gradioApp().appendChild(arPreviewRect);
         }

         var viewportOffset = targetElement.getBoundingClientRect();

-        var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
-        var scaledx = targetElement.naturalWidth*viewportscale
-        var scaledy = targetElement.naturalHeight*viewportscale
-        var cleintRectTop = (viewportOffset.top+window.scrollY)
-        var cleintRectLeft = (viewportOffset.left+window.scrollX)
-        var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
-        var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
-        var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
-        var arscaledx = currentWidth*arscale
-        var arscaledy = currentHeight*arscale
-        var arRectTop = cleintRectCentreY-(arscaledy/2)
-        var arRectLeft = cleintRectCentreX-(arscaledx/2)
-        var arRectWidth = arscaledx
-        var arRectHeight = arscaledy
+        var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
+        var scaledx = targetElement.naturalWidth * viewportscale;
+        var scaledy = targetElement.naturalHeight * viewportscale;
+        var cleintRectTop = (viewportOffset.top + window.scrollY);
+        var cleintRectLeft = (viewportOffset.left + window.scrollX);
+        var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
+        var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
+        var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
+        var arscaledx = currentWidth * arscale;
+        var arscaledy = currentHeight * arscale;
+        var arRectTop = cleintRectCentreY - (arscaledy / 2);
+        var arRectLeft = cleintRectCentreX - (arscaledx / 2);
+        var arRectWidth = arscaledx;
+        var arRectHeight = arscaledy;

-        arPreviewRect.style.top = arRectTop+'px';
-        arPreviewRect.style.left = arRectLeft+'px';
-        arPreviewRect.style.width = arRectWidth+'px';
-        arPreviewRect.style.height = arRectHeight+'px';
+        arPreviewRect.style.top = arRectTop + 'px';
+        arPreviewRect.style.left = arRectLeft + 'px';
+        arPreviewRect.style.width = arRectWidth + 'px';
+        arPreviewRect.style.height = arRectHeight + 'px';

         clearTimeout(arFrameTimeout);
-        arFrameTimeout = setTimeout(function(){
+        arFrameTimeout = setTimeout(function() {
             arPreviewRect.style.display = 'none';
-        },2000);
+        }, 2000);
         arPreviewRect.style.display = 'block';
@@ -81,31 +81,33 @@ function dimensionChange(e, is_width, is_height){
     }
 }

-onUiUpdate(function(){
+onUiUpdate(function() {
     var arPreviewRect = gradioApp().querySelector('#imageARPreview');
-    if(arPreviewRect){
+    if (arPreviewRect) {
         arPreviewRect.style.display = 'none';
     }
     var tabImg2img = gradioApp().querySelector("#tab_img2img");
     if (tabImg2img) {
         var inImg2img = tabImg2img.style.display == "block";
-        if(inImg2img){
+        if (inImg2img) {
             let inputs = gradioApp().querySelectorAll('input');
-            inputs.forEach(function(e){
-                var is_width = e.parentElement.id == "img2img_width"
-                var is_height = e.parentElement.id == "img2img_height"
+            inputs.forEach(function(e) {
+                var is_width = e.parentElement.id == "img2img_width";
+                var is_height = e.parentElement.id == "img2img_height";

-                if((is_width || is_height) && !e.classList.contains('scrollwatch')){
-                    e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
-                    e.classList.add('scrollwatch')
+                if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
+                    e.addEventListener('input', function(e) {
+                        dimensionChange(e, is_width, is_height);
+                    });
+                    e.classList.add('scrollwatch');
                 }
-                if(is_width){
-                    currentWidth = e.value*1.0
+                if (is_width) {
+                    currentWidth = e.value * 1.0;
                 }
-                if(is_height){
-                    currentHeight = e.value*1.0
+                if (is_height) {
+                    currentHeight = e.value * 1.0;
                 }
-            })
+            });
         }
     }
 });
contextMenuInit = function(){ var contextMenuInit = function() {
let eventListenerApplied=false; let eventListenerApplied = false;
let menuSpecs = new Map(); let menuSpecs = new Map();
const uid = function(){ const uid = function() {
return Date.now().toString(36) + Math.random().toString(36).substring(2); return Date.now().toString(36) + Math.random().toString(36).substring(2);
} };
function showContextMenu(event,element,menuEntries){ function showContextMenu(event, element, menuEntries) {
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
let oldMenu = gradioApp().querySelector('#context-menu') let oldMenu = gradioApp().querySelector('#context-menu');
if(oldMenu){ if (oldMenu) {
oldMenu.remove() oldMenu.remove();
} }
let baseStyle = window.getComputedStyle(uiCurrentTab) let baseStyle = window.getComputedStyle(uiCurrentTab);
const contextMenu = document.createElement('nav') const contextMenu = document.createElement('nav');
contextMenu.id = "context-menu" contextMenu.id = "context-menu";
contextMenu.style.background = baseStyle.background contextMenu.style.background = baseStyle.background;
contextMenu.style.color = baseStyle.color contextMenu.style.color = baseStyle.color;
contextMenu.style.fontFamily = baseStyle.fontFamily contextMenu.style.fontFamily = baseStyle.fontFamily;
contextMenu.style.top = posy+'px' contextMenu.style.top = posy + 'px';
contextMenu.style.left = posx+'px' contextMenu.style.left = posx + 'px';
const contextMenuList = document.createElement('ul') const contextMenuList = document.createElement('ul');
contextMenuList.className = 'context-menu-items'; contextMenuList.className = 'context-menu-items';
contextMenu.append(contextMenuList); contextMenu.append(contextMenuList);
menuEntries.forEach(function(entry){ menuEntries.forEach(function(entry) {
let contextMenuEntry = document.createElement('a') let contextMenuEntry = document.createElement('a');
contextMenuEntry.innerHTML = entry['name'] contextMenuEntry.innerHTML = entry['name'];
contextMenuEntry.addEventListener("click", function() { contextMenuEntry.addEventListener("click", function() {
entry['func'](); entry['func']();
}) });
contextMenuList.append(contextMenuEntry); contextMenuList.append(contextMenuEntry);
}) });
gradioApp().appendChild(contextMenu) gradioApp().appendChild(contextMenu);
let menuWidth = contextMenu.offsetWidth + 4; let menuWidth = contextMenu.offsetWidth + 4;
let menuHeight = contextMenu.offsetHeight + 4; let menuHeight = contextMenu.offsetHeight + 4;
...@@ -50,117 +50,123 @@ contextMenuInit = function(){ ...@@ -50,117 +50,123 @@ contextMenuInit = function(){
let windowWidth = window.innerWidth; let windowWidth = window.innerWidth;
let windowHeight = window.innerHeight; let windowHeight = window.innerHeight;
if ( (windowWidth - posx) < menuWidth ) { if ((windowWidth - posx) < menuWidth) {
contextMenu.style.left = windowWidth - menuWidth + "px"; contextMenu.style.left = windowWidth - menuWidth + "px";
} }
if ( (windowHeight - posy) < menuHeight ) { if ((windowHeight - posy) < menuHeight) {
contextMenu.style.top = windowHeight - menuHeight + "px"; contextMenu.style.top = windowHeight - menuHeight + "px";
} }
} }
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){ function appendContextMenuOption(targetElementSelector, entryName, entryFunction) {
var currentItems = menuSpecs.get(targetElementSelector) var currentItems = menuSpecs.get(targetElementSelector);
if(!currentItems){ if (!currentItems) {
currentItems = [] currentItems = [];
menuSpecs.set(targetElementSelector,currentItems); menuSpecs.set(targetElementSelector, currentItems);
} }
let newItem = {'id':targetElementSelector+'_'+uid(), let newItem = {
'name':entryName, id: targetElementSelector + '_' + uid(),
'func':entryFunction, name: entryName,
'isNew':true} func: entryFunction,
isNew: true
};
currentItems.push(newItem) currentItems.push(newItem);
return newItem['id'] return newItem['id'];
} }
function removeContextMenuOption(uid){ function removeContextMenuOption(uid) {
menuSpecs.forEach(function(v) { menuSpecs.forEach(function(v) {
let index = -1 let index = -1;
v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) v.forEach(function(e, ei) {
if(index>=0){ if (e['id'] == uid) {
index = ei;
}
});
if (index >= 0) {
v.splice(index, 1); v.splice(index, 1);
} }
}) });
} }
function addContextMenuEventListener(){ function addContextMenuEventListener() {
if(eventListenerApplied){ if (eventListenerApplied) {
return; return;
} }
gradioApp().addEventListener("click", function(e) { gradioApp().addEventListener("click", function(e) {
if(! e.isTrusted){ if (!e.isTrusted) {
return return;
} }
let oldMenu = gradioApp().querySelector('#context-menu') let oldMenu = gradioApp().querySelector('#context-menu');
if(oldMenu){ if (oldMenu) {
oldMenu.remove() oldMenu.remove();
} }
}); });
gradioApp().addEventListener("contextmenu", function(e) { gradioApp().addEventListener("contextmenu", function(e) {
let oldMenu = gradioApp().querySelector('#context-menu') let oldMenu = gradioApp().querySelector('#context-menu');
if(oldMenu){ if (oldMenu) {
oldMenu.remove() oldMenu.remove();
} }
menuSpecs.forEach(function(v,k) { menuSpecs.forEach(function(v, k) {
if(e.composedPath()[0].matches(k)){ if (e.composedPath()[0].matches(k)) {
showContextMenu(e,e.composedPath()[0],v) showContextMenu(e, e.composedPath()[0], v);
e.preventDefault() e.preventDefault();
} }
})
}); });
eventListenerApplied=true });
eventListenerApplied = true;
} }
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener];
} };
var initResponse = contextMenuInit();
var appendContextMenuOption = initResponse[0];
var removeContextMenuOption = initResponse[1];
var addContextMenuEventListener = initResponse[2];
(function() {
    //Start example Context Menu Items
    let generateOnRepeat = function(genbuttonid, interruptbuttonid) {
        let genbutton = gradioApp().querySelector(genbuttonid);
        let interruptbutton = gradioApp().querySelector(interruptbuttonid);
        if (!interruptbutton.offsetParent) {
            genbutton.click();
        }
        clearInterval(window.generateOnRepeatInterval);
        window.generateOnRepeatInterval = setInterval(function() {
            if (!interruptbutton.offsetParent) {
                genbutton.click();
            }
        },
        500);
    };

    appendContextMenuOption('#txt2img_generate', 'Generate forever', function() {
        generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
    });
    appendContextMenuOption('#img2img_generate', 'Generate forever', function() {
        generateOnRepeat('#img2img_generate', '#img2img_interrupt');
    });

    let cancelGenerateForever = function() {
        clearInterval(window.generateOnRepeatInterval);
    };

    appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
    appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever);
    appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
    appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever);
})();
//End example Context Menu Items

onUiUpdate(function() {
    addContextMenuEventListener();
});
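// Example (illustrative, not part of the commit): registering a throwaway entry with
// the API above and removing it again via the returned id; the selector and label
// here are made up for the sketch.
var exampleMenuId = appendContextMenuOption('#txt2img_generate', 'Example entry', function() {
    console.log('example entry clicked');
});
removeContextMenuOption(exampleMenuId);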
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard

function isValidImageList(files) {
    return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}

function dropReplaceImage(imgWrap, files) {
    if (!isValidImageList(files)) {
        return;
    }

@@ -14,8 +14,8 @@ function dropReplaceImage( imgWrap, files ) {
    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
    const callback = () => {
        const fileInput = imgWrap.querySelector('input[type="file"]');
        if (fileInput) {
            if (files.length === 0) {
                files = new DataTransfer();
                files.items.add(tmpFile);
                fileInput.files = files.files;

@@ -26,32 +26,32 @@ function dropReplaceImage( imgWrap, files ) {
        }
    };

    if (imgWrap.closest('#pnginfo_image')) {
        // special treatment for PNG Info tab, wait for fetch request to finish
        const oldFetch = window.fetch;
        window.fetch = async(input, options) => {
            const response = await oldFetch(input, options);
            if ('api/predict/' === input) {
                const content = await response.text();
                window.fetch = oldFetch;
                window.requestAnimationFrame(() => callback());
                return new Response(content, {
                    status: response.status,
                    statusText: response.statusText,
                    headers: response.headers
                });
            }
            return response;
        };
    } else {
        window.requestAnimationFrame(() => callback());
    }
}
window.document.addEventListener('dragover', e => {
    const target = e.composedPath()[0];
    const imgWrap = target.closest('[data-testid="image"]');
    if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
        return;
    }
    e.stopPropagation();

@@ -65,24 +65,24 @@ window.document.addEventListener('drop', e => {
        return;
    }
    const imgWrap = target.closest('[data-testid="image"]');
    if (!imgWrap) {
        return;
    }
    e.stopPropagation();
    e.preventDefault();
    const files = e.dataTransfer.files;
    dropReplaceImage(imgWrap, files);
});

window.addEventListener('paste', e => {
    const files = e.clipboardData.files;
    if (!isValidImageList(files)) {
        return;
    }

    const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
        .filter(el => uiElementIsVisible(el));
    if (!visibleImageFields.length) {
        return;
    }

@@ -93,5 +93,6 @@ window.addEventListener('paste', e => {
        firstFreeImageField ?
            firstFreeImageField :
            visibleImageFields[visibleImageFields.length - 1]
        , files
    );
});
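// Example (illustrative sketch): dropReplaceImage can also be driven programmatically
// with a constructed DataTransfer; the empty blob below stands in for real PNG bytes,
// and the target is simply the first gradio image wrapper on the page.
var exampleTransfer = new DataTransfer();
exampleTransfer.items.add(new File([new Blob([], {type: 'image/png'})], "example.png", {type: 'image/png'}));
dropReplaceImage(gradioApp().querySelector('[data-testid="image"]'), exampleTransfer.files);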
function keyupEditAttention(event) {
    let target = event.originalTarget || event.composedPath()[0];
    if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
    if (!(event.metaKey || event.ctrlKey)) return;

    let isPlus = event.key == "ArrowUp";
    let isMinus = event.key == "ArrowDown";
    if (!isPlus && !isMinus) return;

    let selectionStart = target.selectionStart;
    let selectionEnd = target.selectionEnd;
    let text = target.value;

    function selectCurrentParenthesisBlock(OPEN, CLOSE) {
        if (selectionStart !== selectionEnd) return false;

        // Find opening parenthesis around current cursor

@@ -44,7 +44,7 @@ function keyupEditAttention(event){
        return true;
    }

    function selectCurrentWord() {
        if (selectionStart !== selectionEnd) return false;
        const delimiters = opts.keyedit_delimiters + " \r\n\t";

@@ -69,20 +69,20 @@ function keyupEditAttention(event){
    event.preventDefault();

    var closeCharacter = ')';
    var delta = opts.keyedit_precision_attention;

    if (selectionStart > 0 && text[selectionStart - 1] == '<') {
        closeCharacter = '>';
        delta = opts.keyedit_precision_extra;
    } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {

        // do not include spaces at the end
        while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
            selectionEnd -= 1;
        }
        if (selectionStart == selectionEnd) {
            return;
        }

        text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);

@@ -97,7 +97,7 @@ function keyupEditAttention(event){
    weight += isPlus ? delta : -delta;
    weight = parseFloat(weight.toPrecision(12));
    if (String(weight).length == 1) weight += ".0";

    if (closeCharacter == ')' && weight == 1) {
        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);

@@ -112,7 +112,7 @@ function keyupEditAttention(event){
    target.selectionStart = selectionStart;
    target.selectionEnd = selectionEnd;
    updateInput(target);
}

addEventListener('keydown', (event) => {
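// Worked example (illustrative): with opts.keyedit_precision_attention = 0.1, selecting
// "cat" in the prompt "a cat on a mat" and pressing Ctrl+ArrowUp first wraps the
// selection as "(cat:1.0)" and then bumps the weight by delta:
//     "a cat on a mat"  ->  "a (cat:1.1) on a mat"
// Pressing Ctrl+ArrowDown brings the weight back to 1, and a ')' block with weight 1
// is redundant, so the wrapper is stripped again (the slice(selectionEnd + 5) above
// drops the five-character ":1.0)" suffix).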
function extensions_apply(_disabled_list, _update_list, disable_all) {
    var disable = [];
    var update = [];

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
        if (x.name.startsWith("enable_") && !x.checked) {
            disable.push(x.name.substring(7));
        }

        if (x.name.startsWith("update_") && x.checked) {
            update.push(x.name.substring(7));
        }
    });

    restart_reload();

    return [JSON.stringify(disable), JSON.stringify(update), disable_all];
}

function extensions_check() {
    var disable = [];

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
        if (x.name.startsWith("enable_") && !x.checked) {
            disable.push(x.name.substring(7));
        }
    });

    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
        x.innerHTML = "Loading...";
    });

    var id = randomId();
    requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
    });

    return [id, JSON.stringify(disable)];
}

function install_extension_from_index(button, url) {
    button.disabled = "disabled";
    button.value = "Installing...";

    var textarea = gradioApp().querySelector('#extension_to_install textarea');
    textarea.value = url;
    updateInput(textarea);

    gradioApp().querySelector('#install_extension_button').click();
}

function config_state_confirm_restore(_, config_state_name, config_restore_type) {

@@ -63,9 +66,9 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
    let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
    if (confirmed) {
        restart_reload();
        gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
            x.innerHTML = "Loading...";
        });
    }
    return [confirmed, config_state_name, config_restore_type];
}
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes

let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function() {
    if (!txt2img_gallery) {
        txt2img_gallery = attachGalleryListeners("txt2img");
    }
    if (!img2img_gallery) {
        img2img_gallery = attachGalleryListeners("img2img");
    }
    if (!modal) {
        modal = gradioApp().getElementById('lightboxModal');
        modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
    }
});

let modalObserver = new MutationObserver(function(mutations) {
    mutations.forEach(function(mutationRecord) {
        let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
            gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
        }
    });
});

function attachGalleryListeners(tab_name) {
    var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
    gallery?.addEventListener('keydown', (e) => {
        if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
            gradioApp().getElementById(tab_name + "_generation_info_button").click();
        }
    });
    return gallery;
}
// mouseover tooltips for various UI elements

var titles = {
    "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
    "Sampling method": "Which algorithm to use to produce the image",
    "GFPGAN": "Restore low quality faces using GFPGAN neural network",

@@ -9,6 +9,7 @@ titles = {
    "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
    "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",

    "\u{1F4D0}": "Auto detect size from img2img",
    "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
    "Batch size": "How many images to create in a single batch (increases generation performance at cost of higher VRAM usage)",
    "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",

@@ -66,8 +67,8 @@ titles = {
    "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",

    "Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
    "Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
    "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",

    "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",

@@ -113,20 +114,18 @@ titles = {
    "Discard weights with matching name": "Regular expression; if a weight's name matches it, the weight is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in the order listed.",
    "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
};
function updateTooltipForSpan(span) {
    if (span.title) return; // already has a title

    let tooltip = localization[titles[span.textContent]] || titles[span.textContent];

    if (!tooltip) {
        tooltip = localization[titles[span.value]] || titles[span.value];
    }

    if (!tooltip) {
        for (const c of span.classList) {
            if (c in titles) {
                tooltip = localization[titles[c]] || titles[c];

@@ -135,16 +134,35 @@ onUiUpdate(function(){
            }
        }
    }

    if (tooltip) {
        span.title = tooltip;
    }
}

function updateTooltipForSelect(select) {
    if (select.onchange != null) return;

    select.onchange = function() {
        select.title = localization[titles[select.value]] || titles[select.value] || "";
    };
}

var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};

onUiUpdate(function(m) {
    m.forEach(function(record) {
        record.addedNodes.forEach(function(node) {
            if (observedTooltipElements[node.tagName]) {
                updateTooltipForSpan(node);
            }
            if (node.tagName == "SELECT") {
                updateTooltipForSelect(node);
            }

            if (node.querySelectorAll) {
                node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
                node.querySelectorAll('select').forEach(updateTooltipForSelect);
            }
        });
    });
});
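// Example (illustrative): a script loaded after this file could register its own
// tooltip by extending the lookup table; the key and text here are hypothetical.
titles["My button"] = "What my button does when clicked";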
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
    function setInactive(elem, inactive) {
        elem.classList.toggle('inactive', !!inactive);
    }

    var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
    var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
    var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');

    gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";

    setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
    setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
    setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);

    return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
}
@@ -4,17 +4,16 @@
 */
function imageMaskResize() {
    const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
    if (!canvases.length) {
        window.removeEventListener('resize', imageMaskResize);
        return;
    }

    const wrapper = canvases[0].closest('.touch-none');
    const previewImage = wrapper.previousElementSibling;

    if (!previewImage.complete) {
        previewImage.addEventListener('load', imageMaskResize);
        return;
    }

@@ -24,15 +23,15 @@ function imageMaskResize() {
    const nh = previewImage.naturalHeight;
    const portrait = nh > nw;

    const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
    const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);

    wrapper.style.width = `${wW}px`;
    wrapper.style.height = `${wH}px`;
    wrapper.style.left = `0px`;
    wrapper.style.top = `0px`;

    canvases.forEach(c => {
        c.style.width = c.style.height = '';
        c.style.maxWidth = '100%';
        c.style.maxHeight = '100%';

@@ -41,4 +40,4 @@ function imageMaskResize() {
}

onUiUpdate(imageMaskResize);
window.addEventListener('resize', imageMaskResize);
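// Worked example (illustrative): for a portrait 512x768 preview (nh > nw) displayed
// in a 400x300 area, the wrapper is fitted to the image's aspect ratio:
//     wW = Math.min(400, 300 / 768 * 512)  // -> 200
//     wH = Math.min(300, 300 / 768 * 768)  // -> 300
// so the mask canvases are constrained to a 200x300 box matching the visible image.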
window.onload = (function() {
    window.addEventListener('drop', e => {
        const target = e.composedPath()[0];
        if (target.placeholder.indexOf("Prompt") == -1) return;

@@ -10,7 +10,7 @@ window.onload = (function(){
        const imgParent = gradioApp().getElementById(prompt_target);
        const files = e.dataTransfer.files;
        const fileInput = imgParent.querySelector('input[type="file"]');
        if (fileInput) {
            fileInput.files = files;
            fileInput.dispatchEvent(new Event('change'));
        }
window.addEventListener('gamepadconnected', (e) => {
    const index = e.gamepad.index;
    let isWaiting = false;
    setInterval(async() => {
        if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
        const gamepad = navigator.getGamepads()[index];
        const xValue = gamepad.axes[0];

@@ -14,7 +14,7 @@ window.addEventListener('gamepadconnected', (e) => {
        }
        if (isWaiting) {
            await sleepUntil(() => {
                const xValue = navigator.getGamepads()[index].axes[0];
                if (xValue < 0.3 && xValue > -0.3) {
                    return true;
                }
// localization = {} -- the dict with translations is created by the backend

var ignore_ids_for_localization = {
    setting_sd_hypernetwork: 'OPTION',
    setting_sd_model_checkpoint: 'OPTION',
    modelmerger_primary_model_name: 'OPTION',
    modelmerger_secondary_model_name: 'OPTION',
    modelmerger_tertiary_model_name: 'OPTION',

@@ -17,111 +16,111 @@ ignore_ids_for_localization={
    setting_realesrgan_enabled_models: 'SPAN',
    extras_upscaler_1: 'SPAN',
    extras_upscaler_2: 'SPAN',
};

var re_num = /^[.\d]+$/;
var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;

var original_lines = {};
var translated_lines = {};

function hasLocalization() {
    return window.localization && Object.keys(window.localization).length > 0;
}

function textNodesUnder(el) {
    var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
    while ((n = walk.nextNode())) a.push(n);
    return a;
}

function canBeTranslated(node, text) {
    if (!text) return false;
    if (!node.parentElement) return false;

    var parentType = node.parentElement.nodeName;
    if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;

    if (parentType == 'OPTION' || parentType == 'SPAN') {
        var pnode = node;
        for (var level = 0; level < 4; level++) {
            pnode = pnode.parentElement;
            if (!pnode) break;

            if (ignore_ids_for_localization[pnode.id] == parentType) return false;
        }
    }

    if (re_num.test(text)) return false;
    if (re_emoji.test(text)) return false;
    return true;
}

function getTranslation(text) {
    if (!text) return undefined;

    if (translated_lines[text] === undefined) {
        original_lines[text] = 1;
    }

    var tl = localization[text];
    if (tl !== undefined) {
        translated_lines[tl] = 1;
    }

    return tl;
}

function processTextNode(node) {
    var text = node.textContent.trim();

    if (!canBeTranslated(node, text)) return;

    var tl = getTranslation(text);
    if (tl !== undefined) {
        node.textContent = tl;
    }
}

function processNode(node) {
    if (node.nodeType == 3) {
        processTextNode(node);
        return;
    }

    if (node.title) {
        let tl = getTranslation(node.title);
        if (tl !== undefined) {
            node.title = tl;
        }
    }

    if (node.placeholder) {
        let tl = getTranslation(node.placeholder);
        if (tl !== undefined) {
            node.placeholder = tl;
        }
    }

    textNodesUnder(node).forEach(function(node) {
        processTextNode(node);
    });
}

function dumpTranslations() {
    if (!hasLocalization()) {
        // If we don't have any localization,
        // we will not have traversed the app to find
        // original_lines, so do that now.
        processNode(gradioApp());
    }
    var dumped = {};
    if (localization.rtl) {
        dumped.rtl = true;
    }

    for (const text in original_lines) {
        if (dumped[text] !== undefined) continue;
        dumped[text] = localization[text] || text;
    }

@@ -129,7 +128,7 @@ function dumpTranslations(){
}

function download_localization() {
    var text = JSON.stringify(dumpTranslations(), null, 4);
    var element = document.createElement('a');
    element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));

@@ -142,20 +141,20 @@ function download_localization() {
    document.body.removeChild(element);
}

document.addEventListener("DOMContentLoaded", function() {
    if (!hasLocalization()) {
        return;
    }

    onUiUpdate(function(m) {
        m.forEach(function(mutation) {
            mutation.addedNodes.forEach(function(node) {
                processNode(node);
            });
        });
    });

    processNode(gradioApp());

    if (localization.rtl) { // if the language is from right to left,
        (new MutationObserver((mutations, observer) => { // wait for the style to load

@@ -170,8 +169,8 @@ document.addEventListener("DOMContentLoaded", function () {
                            }
                        }
                    }
                });
            });
        })).observe(gradioApp(), {childList: true});
    }
});
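// Example (illustrative): the backend-provided dict is a flat mapping from source
// strings to translations, plus the optional rtl flag used above, e.g.:
//     localization = {"Sampling steps": "Schritte", "rtl": false};
// processNode() then rewrites matching text nodes, titles and placeholders in place.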
@@ -4,14 +4,14 @@ let lastHeadImg = null;
let notificationButton = null;

onUiUpdate(function() {
    if (notificationButton == null) {
        notificationButton = gradioApp().getElementById('request_notifications');

        if (notificationButton != null) {
            notificationButton.addEventListener('click', () => {
                void Notification.requestPermission();
            }, true);
        }
    }

@@ -42,7 +42,7 @@ onUiUpdate(function(){
        }
    );

    notification.onclick = function(_) {
        parent.focus();
        this.close();
    };
// code related to showing and updating progressbar shown as the image is being made

function rememberGallerySelection() {

}

function getGallerySelectedIndex() {

}

function request(url, data, handler, errorHandler) {
    var xhr = new XMLHttpRequest();
    xhr.open("POST", url, true);
    xhr.setRequestHeader("Content-Type", "application/json");
    xhr.onreadystatechange = function() {
        if (xhr.readyState === 4) {
            if (xhr.status === 200) {
                try {
                    var js = JSON.parse(xhr.responseText);
                    handler(js);
                } catch (error) {
                    console.error(error);
                    errorHandler();
                }
            } else {
                errorHandler();
            }
        }
    };

@@ -31,147 +31,147 @@ function request(url, data, handler, errorHandler){
    xhr.send(js);
}
function pad2(x) {
    return x < 10 ? '0' + x : x;
}

function formatTime(secs) {
    if (secs > 3600) {
        return pad2(Math.floor(secs / 60 / 60)) + ":" + pad2(Math.floor(secs / 60) % 60) + ":" + pad2(Math.floor(secs) % 60);
    } else if (secs > 60) {
        return pad2(Math.floor(secs / 60)) + ":" + pad2(Math.floor(secs) % 60);
    } else {
        return Math.floor(secs) + "s";
    }
}
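// Example (illustrative): formatTime picks a format based on magnitude:
//     formatTime(42)    // -> "42s"
//     formatTime(75)    // -> "01:15"
//     formatTime(3700)  // -> "01:01:40"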
function setTitle(progress) {
    var title = 'Stable Diffusion';

    if (opts.show_progress_in_title && progress) {
        title = '[' + progress.trim() + '] ' + title;
    }

    if (document.title != title) {
        document.title = title;
    }
}

function randomId() {
    return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + ")";
}

// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
// calls onProgress every time there is a progress update
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {
    var dateStart = new Date();
    var wasEverActive = false;
    var parentProgressbar = progressbarContainer.parentNode;
    var parentGallery = gallery ? gallery.parentNode : null;

    var divProgress = document.createElement('div');
    divProgress.className = 'progressDiv';
    divProgress.style.display = opts.show_progressbar ? "block" : "none";
    var divInner = document.createElement('div');
    divInner.className = 'progress';

    divProgress.appendChild(divInner);
    parentProgressbar.insertBefore(divProgress, progressbarContainer);

    if (parentGallery) {
        var livePreview = document.createElement('div');
        livePreview.className = 'livePreview';
        parentGallery.insertBefore(livePreview, gallery);
    }

    var removeProgressBar = function() {
        setTitle("");
        parentProgressbar.removeChild(divProgress);
        if (parentGallery) parentGallery.removeChild(livePreview);
        atEnd();
    };

    var fun = function(id_task, id_live_preview) {
        request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
            if (res.completed) {
                removeProgressBar();
                return;
            }

            var rect = progressbarContainer.getBoundingClientRect();

            if (rect.width) {
                divProgress.style.width = rect.width + "px";
            }

            let progressText = "";

            divInner.style.width = ((res.progress || 0) * 100.0) + '%';
            divInner.style.background = res.progress ? "" : "transparent";

            if (res.progress > 0) {
                progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';
            }

            if (res.eta) {
                progressText += " ETA: " + formatTime(res.eta);
            }

            setTitle(progressText);

            if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
                progressText = res.textinfo + " " + progressText;
            }

            divInner.textContent = progressText;

            var elapsedFromStart = (new Date() - dateStart) / 1000;

            if (res.active) wasEverActive = true;

            if (!res.active && wasEverActive) {
                removeProgressBar();
                return;
            }

            if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {
                removeProgressBar();
                return;
            }

            if (res.live_preview && gallery) {
                rect = gallery.getBoundingClientRect();
                if (rect.width) {
                    livePreview.style.width = rect.width + "px";
                    livePreview.style.height = rect.height + "px";
                }

                var img = new Image();
                img.onload = function() {
                    livePreview.appendChild(img);
                    if (livePreview.childElementCount > 2) {
                        livePreview.removeChild(livePreview.firstElementChild);
                    }
                };
                img.src = res.live_preview;
            }

            if (onProgress) {
                onProgress(res);
            }

            setTimeout(() => {
                fun(id_task, res.id_live_preview);
            }, opts.live_preview_refresh_period || 500);
        }, function() {
            removeProgressBar();
        });
    };

    fun(id_task, 0);
}

function start_training_textual_inversion() {
    gradioApp().querySelector('#ti_error').innerHTML = '';

    var id = randomId();
    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
    });

    var res = args_to_array(arguments);

    res[0] = id;

    return res;
}
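// Example (illustrative sketch): the same pattern used by callers such as
// start_training_textual_inversion above - generate a task id, hand it to the backend
// job, and let requestProgress poll until completion; the element ids are hypothetical.
var exampleTaskId = randomId();
requestProgress(exampleTaskId, gradioApp().getElementById('example_output'), null, function() {
    console.log('task finished');
}, function(progress) {
    console.log('progress:', progress.textinfo);
});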
// various hints and extra info for the settings tab

var settingsHintsSetup = false;

onOptionsChanged(function() {
    if (settingsHintsSetup) return;
    settingsHintsSetup = true;

    gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
        var name = div.id.substr(8);
        var commentBefore = opts._comments_before[name];
        var commentAfter = opts._comments_after[name];

        if (!commentBefore && !commentAfter) return;

        var span = null;
        if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
        else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
        else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
        else span = div.querySelector('label span').firstChild;

        if (!span) return;

        if (commentBefore) {
            var comment = document.createElement('DIV');
            comment.className = 'settings-comment';
            comment.innerHTML = commentBefore;
            span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
            span.parentElement.insertBefore(comment, span);
            span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
        }
        if (commentAfter) {
            comment = document.createElement('DIV');
            comment.className = 'settings-comment';
            comment.innerHTML = commentAfter;
            span.parentElement.insertBefore(comment, span.nextSibling);
            span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
        }
    });
});

function settingsHintsShowQuicksettings() {
    requestGet("./internal/quicksettings-hint", {}, function(data) {
        var table = document.createElement('table');
        table.className = 'settings-value-table';

        data.forEach(function(obj) {
            var tr = document.createElement('tr');
            var td = document.createElement('td');
            td.textContent = obj.name;
            tr.appendChild(td);

            td = document.createElement('td');
            td.textContent = obj.label;
            tr.appendChild(td);

            table.appendChild(tr);
        });

        popup(table);
    });
}
@@ -3,25 +3,23 @@ import subprocess
import os
import sys
import importlib.util
import platform
import json
from functools import lru_cache

from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir

args, _ = cmd_args.parser.parse_known_args()

python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
dir_repos = "repositories"

# Whether to default to printing command output
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")

if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'

@@ -57,65 +55,52 @@ Use --skip-python-version-check to suppress this warning.
""")


@lru_cache()
def commit_hash():
    try:
        return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
    except Exception:
        return "<none>"


@lru_cache()
def git_tag():
    try:
        return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
    except Exception:
        return "<none>"


def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
    if desc is not None:
        print(desc)

    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore',
    }

    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

    result = subprocess.run(**run_kwargs)

    if result.returncode != 0:
        error_bits = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))

    return (result.stdout or "")


def is_installed(package):

@@ -131,11 +116,7 @@ def repo_dir(name):
    return os.path.join(script_path, dir_repos, name)


def run_pip(command, desc=None, live=default_command_live):
    if args.skip_install:
        return

@@ -143,8 +124,9 @@ def run_pip(command, desc=None, live=False):
    return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)


def check_run_python(code: str) -> bool:
    result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
    return result.returncode == 0


def git_clone(url, dir, name, commithash=None):

@@ -237,13 +219,14 @@ def run_extensions_installers(settings_file):
def prepare_environment():
    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
    clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")

@@ -270,8 +253,11 @@ def prepare_environment():
    if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

    if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
        raise RuntimeError(
            'Torch is not able to use GPU; '
            'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
        )

    if not is_installed("gfpgan"):
        run_pip(f"install {gfpgan_package}", "gfpgan")

@@ -294,8 +280,8 @@ def prepare_environment():
    elif platform.system() == "Linux":
        run_pip(f"install {xformers_package}", "xformers")

    if not is_installed("ngrok") and args.ngrok:
        run_pip("install ngrok", "ngrok")

    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
......
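A practical consequence of splitting TORCH_INDEX_URL out of TORCH_COMMAND: the wheel index can now be swapped on its own. A minimal sketch, assuming it runs before prepare_environment() and that the alternate index exists (the cu117 URL is illustrative):

import os

# pick a different CUDA build of torch without rewriting the whole install command
os.environ["TORCH_INDEX_URL"] = "https://download.pytorch.org/whl/cu117"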
@@ -223,8 +223,9 @@ for key in _options:
     if(_options[key].dest != 'help'):
         flag = _options[key]
         _type = str
-        if _options[key].default is not None: _type = type(_options[key].default)
-        flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+        if _options[key].default is not None:
+            _type = type(_options[key].default)
+        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})

 FlagsModel = create_model("Flags", **flags)
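For readers unfamiliar with pydantic's create_model: the loop builds one (type, Field) tuple per argparse flag and assembles the model from them dynamically. A standalone sketch of the same pattern (both flags are illustrative):

from pydantic import Field, create_model

flags = {
    "no_hashing": (bool, Field(default=False, description="disable sha256 hashing of checkpoints")),
    "subpath": (str, Field(default="", description="customize the subpath for gradio")),
}
FlagsModel = create_model("Flags", **flags)
print(FlagsModel().no_hashing)  # -> False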
@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
     cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")

 class ScriptsList(BaseModel):
-    txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
-    img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
\ No newline at end of file
+    txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
+
+
+class ScriptArg(BaseModel):
+    label: str = Field(default=None, title="Label", description="Name of the argument in UI")
+    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
+    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
+    maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
+    step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
+    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+
+
+class ScriptInfo(BaseModel):
+    name: str = Field(default=None, title="Name", description="Script name")
+    is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
+    is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
+    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
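A hedged example of what one populated model might look like; the script name and argument values are invented for illustration, and the API endpoint that serves these lives in a collapsed part of this diff:

info = ScriptInfo(
    name="some script",
    is_alwayson=False,
    is_img2img=False,
    args=[ScriptArg(label="Steps", value=20, minimum=1, maximum=150, step=1)],
)
print(info.json())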
 import argparse
+import json
 import os

-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file  # noqa: F401

 parser = argparse.ArgumentParser()

@@ -39,7 +40,8 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
 parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
-parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
+parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
 parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
 parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
 parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))

@@ -103,3 +105,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
 parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
+parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
 # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py

 import math
-import numpy as np
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional

-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
 from basicsr.utils.registry import ARCH_REGISTRY


 def calc_mean_std(feat, eps=1e-5):

@@ -163,8 +161,8 @@ class Fuse_sft_block(nn.Module):
 class CodeFormer(VQAutoEncoder):
     def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
-                connect_list=['32', '64', '128', '256'],
-                fix_modules=['quantize','generator']):
+                connect_list=('32', '64', '128', '256'),
+                fix_modules=('quantize', 'generator')):
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

         if fix_modules is not None:
...
@@ -5,11 +5,9 @@ VQGAN code, adapted from the original created by the Unleashing Transformers authors
 https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

 '''
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import copy
 from basicsr.utils import get_root_logger
 from basicsr.utils.registry import ARCH_REGISTRY

@@ -328,7 +326,7 @@ class Generator(nn.Module):
 @ARCH_REGISTRY.register()
 class VQAutoEncoder(nn.Module):
-    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
         super().__init__()
         logger = get_root_logger()

@@ -339,7 +337,7 @@ class VQAutoEncoder(nn.Module):
         self.embed_dim = emb_dim
         self.ch_mult = ch_mult
         self.resolution = img_size
-        self.attn_resolutions = attn_resolutions
+        self.attn_resolutions = attn_resolutions or [16]
         self.quantizer_type = quantizer
         self.encoder = Encoder(
             self.in_channels,
...
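Both changes above retire mutable default arguments: a default list is a single object shared across every call of the function. The codeformer fix swaps in immutable tuples; the vqgan_arch fix uses a None sentinel. A minimal sketch of the hazard (names are illustrative):

def bad(acc=[]):           # one list object shared by every call
    acc.append(1)
    return acc

def good(acc=None):        # None sentinel: fresh list per call
    acc = acc or []
    acc.append(1)
    return acc

print(bad())   # [1]
print(bad())   # [1, 1] -- state leaked from the first call
print(good())  # [1]
print(good())  # [1]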
@@ -33,11 +33,9 @@ def setup_model(dirname):
     try:
         from torchvision.transforms.functional import normalize
         from modules.codeformer.codeformer_arch import CodeFormer
-        from basicsr.utils.download_util import load_file_from_url
-        from basicsr.utils import imwrite, img2tensor, tensor2img
+        from basicsr.utils import img2tensor, tensor2img
         from facelib.utils.face_restoration_helper import FaceRestoreHelper
         from facelib.detection.retinaface import retinaface
-        from modules.shared import cmd_opts

         net_class = CodeFormer

@@ -96,7 +94,7 @@
             self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
             self.face_helper.align_warp_face()

-            for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+            for cropped_face in self.face_helper.cropped_faces:
                 cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                 normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                 cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
...
@@ -14,7 +14,7 @@ from collections import OrderedDict
 import git

 from modules import shared, extensions
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+from modules.paths_internal import script_path, config_states_dir

 all_config_states = OrderedDict()

@@ -35,7 +35,7 @@ def list_config_states():
         j["filepath"] = path
         config_states.append(j)

-    config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+    config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)

     for cs in config_states:
         timestamp = time.asctime(time.gmtime(cs["created_at"]))

@@ -83,6 +83,8 @@ def get_extension_config():
     ext_config = {}

     for ext in extensions.extensions:
+        ext.read_info_from_repo()
+
         entry = {
             "name": ext.name,
             "path": ext.path,
...
@@ -2,7 +2,6 @@ import os
 import re

 import torch
-from PIL import Image
 import numpy as np

 from modules import modelloader, paths, deepbooru_model, devices, images, shared

@@ -79,7 +78,7 @@ class DeepDanbooru:
         res = []

-        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}

         for tag in [x for x in tags if x not in filtertags]:
             probability = probability_dict[tag]
...
@@ -65,7 +65,7 @@ def enable_tf32():
     # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
     # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-    if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+    if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
         torch.backends.cudnn.benchmark = True

     torch.backends.cuda.matmul.allow_tf32 = True
...
@@ -6,7 +6,7 @@ from PIL import Image
 from basicsr.utils.download_util import load_file_from_url

 import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
+from modules import modelloader, images, devices
 from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts

@@ -16,9 +16,7 @@ def mod2normal(state_dict):
     # this code is copied from https://github.com/victorca25/iNNfer
     if 'conv_first.weight' in state_dict:
         crt_net = {}
-        items = []
-        for k, v in state_dict.items():
-            items.append(k)
+        items = list(state_dict)

         crt_net['model.0.weight'] = state_dict['conv_first.weight']
         crt_net['model.0.bias'] = state_dict['conv_first.bias']

@@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
     if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
         re8x = 0
         crt_net = {}
-        items = []
-        for k, v in state_dict.items():
-            items.append(k)
+        items = list(state_dict)

         crt_net['model.0.weight'] = state_dict['conv_first.weight']
         crt_net['model.0.bias'] = state_dict['conv_first.bias']
...
@@ -2,7 +2,6 @@
 from collections import OrderedDict
 import math
-import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -438,9 +437,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
     padding = padding if pad_type == 'zero' else 0

     if convtype=='PartialConv2D':
+        from torchvision.ops import PartialConv2d  # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
         c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                dilation=dilation, bias=bias, groups=groups)
     elif convtype=='DeformConv2D':
+        from torchvision.ops import DeformConv2d  # not tested
         c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                dilation=dilation, bias=bias, groups=groups)
     elif convtype=='Conv3D':
...
 import os
 import sys
+import threading
 import traceback
-import time
-from datetime import datetime

 import git

 from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401

 extensions = []

@@ -25,6 +24,8 @@ def active():

 class Extension:
+    lock = threading.Lock()
+
     def __init__(self, name, path, enabled=True, is_builtin=False):
         self.name = name
         self.path = path

@@ -43,8 +44,13 @@ class Extension:
         if self.is_builtin or self.have_info_from_repo:
             return

-        self.have_info_from_repo = True
+        with self.lock:
+            if self.have_info_from_repo:
+                return
+
+            self.do_read_info_from_repo()

+    def do_read_info_from_repo(self):
         repo = None
         try:
             if os.path.exists(os.path.join(self.path, ".git")):

@@ -59,18 +65,18 @@ class Extension:
         try:
             self.status = 'unknown'
             self.remote = next(repo.remote().urls, None)
-            head = repo.head.commit
             self.commit_date = repo.head.commit.committed_date
-            ts = time.asctime(time.gmtime(self.commit_date))
             if repo.active_branch:
                 self.branch = repo.active_branch.name
-            self.commit_hash = head.hexsha
-            self.version = f'{self.commit_hash[:8]} ({ts})'
+            self.commit_hash = repo.head.commit.hexsha
+            self.version = repo.git.describe("--always", "--tags")  # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care

         except Exception as ex:
             print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
             self.remote = None

+        self.have_info_from_repo = True
+
     def list_files(self, subdir, extension):
         from modules import scripts
...
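The read_info_from_repo change is a double-checked guard: have_info_from_repo is tested once on the cheap unlocked path and again under the class-level lock, so concurrent threads cannot both run the expensive git inspection. A stripped-down sketch of the pattern (names are illustrative):

import threading

class Cached:
    lock = threading.Lock()

    def __init__(self):
        self.ready = False

    def ensure(self):
        if self.ready:           # fast path: no locking once initialized
            return
        with self.lock:
            if self.ready:       # re-check: another thread may have finished first
                return
            self.expensive_init()
            self.ready = True

    def expensive_init(self):
        pass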
@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
     """call deactivate for extra networks in extra_network_data in specified order, then call
     deactivate for all remaining registered networks"""

-    for extra_network_name, extra_network_args in extra_network_data.items():
+    for extra_network_name in extra_network_data:
         extra_network = extra_network_registry.get(extra_network_name, None)

         if extra_network is None:
             continue
...
-from modules import extra_networks, shared, extra_networks
+from modules import extra_networks, shared
 from modules.hypernetworks import hypernetwork
...
@@ -136,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     result_is_instruct_pix2pix_model = False

     if theta_func2:
-        shared.state.textinfo = f"Loading B"
+        shared.state.textinfo = "Loading B"
         print(f"Loading {secondary_model_info.filename}...")
         theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
     else:
         theta_1 = None

     if theta_func1:
-        shared.state.textinfo = f"Loading C"
+        shared.state.textinfo = "Loading C"
         print(f"Loading {tertiary_model_info.filename}...")
         theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

@@ -242,9 +242,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     shared.state.textinfo = "Saving"
     print(f"Saving to {output_modelname}...")

-    metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+    metadata = None

     if save_metadata:
+        metadata = {"format": "pt"}
+
         merge_recipe = {
             "type": "webui", # indicate this model was merged with webui's built-in merger
             "primary_model_hash": primary_model_info.sha256,

@@ -262,15 +264,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
         }

         metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

+        sd_merge_models = {}
+
         def add_model_metadata(checkpoint_info):
             checkpoint_info.calculate_shorthash()
-            metadata["sd_merge_models"][checkpoint_info.sha256] = {
+            sd_merge_models[checkpoint_info.sha256] = {
                 "name": checkpoint_info.name,
                 "legacy_hash": checkpoint_info.hash,
                 "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
             }

-            metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+            sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))

         add_model_metadata(primary_model_info)
         if secondary_model_info:

@@ -278,7 +282,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
         if tertiary_model_info:
             add_model_metadata(tertiary_model_info)

-        metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+        metadata["sd_merge_models"] = json.dumps(sd_merge_models)

     _, extension = os.path.splitext(output_modelname)
     if extension.lower() == ".safetensors":
...
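With save_metadata off, metadata stays None and the .safetensors header gets nothing extra; with it on, the header carries two JSON-encoded strings. A hedged illustration of the resulting shape (hashes and names are placeholders, not real values):

metadata = {
    "format": "pt",
    "sd_merge_recipe": '{"type": "webui", "primary_model_hash": "<sha256>", "...": "..."}',
    "sd_merge_models": '{"<sha256>": {"name": "modelA", "legacy_hash": "12345678", "sd_merge_recipe": null}}',
}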
 import base64
-import html
 import io
-import math
 import os
 import re

 import gradio as gr
 from modules.paths import data_path
 from modules import shared, ui_tempdir, script_callbacks
-import tempfile
 from PIL import Image

 re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'

@@ -23,14 +19,14 @@ registered_param_bindings = []

 class ParamBinding:
-    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
         self.paste_button = paste_button
         self.tabname = tabname
         self.source_text_component = source_text_component
         self.source_image_component = source_image_component
         self.source_tabname = source_tabname
         self.override_settings_component = override_settings_component
-        self.paste_field_names = paste_field_names
+        self.paste_field_names = paste_field_names or []


 def reset():

@@ -251,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         lines.append(lastline)
         lastline = ''

-    for i, line in enumerate(lines):
+    for line in lines:
         line = line.strip()
         if line.startswith("Negative prompt:"):
             done_with_prompt = True

@@ -312,6 +308,8 @@ infotext_to_setting_name_mapping = [
     ('UniPC skip type', 'uni_pc_skip_type'),
     ('UniPC order', 'uni_pc_order'),
     ('UniPC lower order final', 'uni_pc_lower_order_final'),
+    ('Token merging ratio', 'token_merging_ratio'),
+    ('Token merging ratio hr', 'token_merging_ratio_hr'),
     ('RNG', 'randn_source'),
     ('NGMS', 's_min_uncond'),
 ]
...
@@ -78,7 +78,7 @@ def setup_model(dirname):
     try:
         from gfpgan import GFPGANer
-        from facexlib import detection, parsing
+        from facexlib import detection, parsing  # noqa: F401
         global user_path
         global have_gfpgan
         global gfpgan_constructor
...
-import csv
 import datetime
 import glob
 import html

@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
 from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

-from collections import defaultdict, deque
+from collections import deque
 from statistics import stdev, mean

@@ -178,34 +177,34 @@ class Hypernetwork:
     def weights(self):
         res = []

-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 res += layer.parameters()

         return res

     def train(self, mode=True):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.train(mode=mode)
                 for param in layer.parameters():
                     param.requires_grad = mode

     def to(self, device):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.to(device)

         return self

     def set_multiplier(self, multiplier):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.multiplier = multiplier

         return self

     def eval(self):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.eval()
                 for param in layer.parameters():

@@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
     k = self.to_k(context_k)
     v = self.to_v(context_v)

-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))

     sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

@@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     try:
         sd_hijack_checkpoint.add()

-        for i in range((steps-initial_step) * gradient_step):
+        for _ in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
...
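The map-to-generator change in attention_CrossAttention_forward is stylistic; the rearrange itself splits the heads out of the channel axis and folds them into the batch axis. A quick shape check, assuming einops and torch are installed:

import torch
from einops import rearrange

h = 8                        # attention heads
t = torch.randn(2, 77, 512)  # (batch, tokens, heads * head_dim)
print(rearrange(t, 'b n (h d) -> (b h) n d', h=h).shape)  # torch.Size([16, 77, 64])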
 import html
-import os
-import re

 import gradio as gr
 import modules.hypernetworks.hypernetwork
 from modules import devices, sd_hijack, shared

 not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]

 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

-    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""

 def train_hypernetwork(*args):
...
@@ -13,17 +13,24 @@ import numpy as np
 import piexif
 import piexif.helper
 from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from fonts.ttf import Roboto
 import string
 import json
 import hashlib

 from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.paths_internal import roboto_ttf_file
+from modules.shared import opts

 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)


+def get_font(fontsize: int):
+    try:
+        return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
+    except Exception:
+        return ImageFont.truetype(roboto_ttf_file, fontsize)
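get_font is now a module-level helper that falls back to the bundled Roboto file whenever the user-configured opts.font cannot be loaded. A hedged usage sketch:

from PIL import Image, ImageDraw

im = Image.new("RGB", (256, 64), "white")
draw = ImageDraw.Draw(im)
draw.text((8, 8), "hello", font=get_font(24), fill="black")  # silently falls back to Roboto on a bad opts.font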
 def image_grid(imgs, batch_size=1, rows=None):
     if rows is None:
         if opts.n_rows > 0:

@@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
             lines.append(word)
         return lines

-    def get_font(fontsize):
-        try:
-            return ImageFont.truetype(opts.font or Roboto, fontsize)
-        except Exception:
-            return ImageFont.truetype(Roboto, fontsize)
-
     def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
-        for i, line in enumerate(lines):
+        for line in lines:
             fnt = initial_fnt
             fontsize = initial_fontsize
             while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:

@@ -409,13 +410,13 @@ class FilenameGenerator:
         time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
         try:
             time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
-        except pytz.exceptions.UnknownTimeZoneError as _:
+        except pytz.exceptions.UnknownTimeZoneError:
             time_zone = None

         time_zone_time = time_datetime.astimezone(time_zone)
         try:
             formatted_time = time_zone_time.strftime(time_format)
-        except (ValueError, TypeError) as _:
+        except (ValueError, TypeError):
             formatted_time = time_zone_time.strftime(self.default_time_format)

         return sanitize_filename_part(formatted_time, replace_spaces=False)

@@ -472,15 +473,52 @@ def get_next_sequence_number(path, basename):
     prefix_length = len(basename)
     for p in os.listdir(path):
         if p.startswith(basename):
-            l = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+            parts = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
             try:
-                result = max(int(l[0]), result)
+                result = max(int(parts[0]), result)
             except ValueError:
                 pass

     return result + 1
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+    if extension is None:
+        extension = os.path.splitext(filename)[1]
+
+    image_format = Image.registered_extensions()[extension]
+
+    existing_pnginfo = existing_pnginfo or {}
+    if opts.enable_pnginfo:
+        existing_pnginfo['parameters'] = geninfo
+
+    if extension.lower() == '.png':
+        pnginfo_data = PngImagePlugin.PngInfo()
+        for k, v in (existing_pnginfo or {}).items():
+            pnginfo_data.add_text(k, str(v))
+
+        image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+    elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+        if image.mode == 'RGBA':
+            image = image.convert("RGB")
+        elif image.mode == 'I;16':
+            image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
+
+        image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
+
+        if opts.enable_pnginfo and geninfo is not None:
+            exif_bytes = piexif.dump({
+                "Exif": {
+                    piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
+                },
+            })
+
+            piexif.insert(exif_bytes, filename)
+
+    else:
+        image.save(filename, format=image_format, quality=opts.jpeg_quality)
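save_image_with_geninfo centralizes the format-specific metadata logic that used to live inline in save_image: PNG text chunks for .png, an EXIF UserComment for .jpg/.jpeg/.webp, and a plain save otherwise. A hedged usage sketch (paths and the infotext are made up):

from PIL import Image

img = Image.new("RGB", (64, 64))
save_image_with_geninfo(img, "a prompt, Steps: 20", "/tmp/example.png")  # parameters land in PNG text chunks
save_image_with_geninfo(img, "a prompt, Steps: 20", "/tmp/example.jpg")  # same info goes into EXIF UserComment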
 def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
     """Save an image.

@@ -565,38 +603,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         info = params.pnginfo.get(pnginfo_section_name, None)

     def _atomically_save_image(image_to_save, filename_without_extension, extension):
-        # save image with .tmp extension to avoid race condition when another process detects new image in the directory
+        """
+        save image with .tmp extension to avoid race condition when another process detects new image in the directory
+        """
         temp_file_path = f"{filename_without_extension}.tmp"
-        image_format = Image.registered_extensions()[extension]

-        if extension.lower() == '.png':
-            pnginfo_data = PngImagePlugin.PngInfo()
-            if opts.enable_pnginfo:
-                for k, v in params.pnginfo.items():
-                    pnginfo_data.add_text(k, str(v))
-
-            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
-
-        elif extension.lower() in (".jpg", ".jpeg", ".webp"):
-            if image_to_save.mode == 'RGBA':
-                image_to_save = image_to_save.convert("RGB")
-            elif image_to_save.mode == 'I;16':
-                image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
-
-            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
-
-            if opts.enable_pnginfo and info is not None:
-                exif_bytes = piexif.dump({
-                    "Exif": {
-                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
-                    },
-                })
-
-                piexif.insert(exif_bytes, temp_file_path)
-        else:
-            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
-
-        # atomically rename the file with correct extension
+        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+
         os.replace(temp_file_path, filename_without_extension + extension)

     fullfn_without_extension, extension = os.path.splitext(params.filename)
...
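The .tmp-then-os.replace sequence kept above is the standard atomic-write pattern: a process watching the output directory sees either no file or a complete one, never a half-written image. The same idea in isolation:

import os

def atomic_write(path: str, data: bytes):
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        f.write(data)
    os.replace(tmp, path)  # atomic rename on the same filesystem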
-import math
 import os
-import sys
-import traceback

 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError

-from modules import devices, sd_samplers
+from modules import sd_samplers
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
-import modules.images as images
 import modules.scripts

@@ -59,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
             # try to find corresponding mask for an image using simple filename matching
             mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
             # if not found use first one ("same mask for all images" use-case)
-            if not mask_image_path in inpaint_masks:
+            if mask_image_path not in inpaint_masks:
                 mask_image_path = inpaint_masks[0]
             mask_image = Image.open(mask_image_path)
             p.image_mask = mask_image
...
@@ -11,7 +11,6 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

-import modules.shared as shared
 from modules import devices, paths, shared, lowvram, modelloader, errors

 blip_image_eval_size = 384

@@ -160,7 +159,7 @@ class InterrogateModels:
         text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

         top_count = min(top_count, len(text_array))
-        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+        text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
         text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
         text_features /= text_features.norm(dim=-1, keepdim=True)

@@ -208,8 +207,8 @@
                 image_features /= image_features.norm(dim=-1, keepdim=True)

-                for name, topn, items in self.categories():
-                    matches = self.rank(image_features, items, top_count=topn)
+                for cat in self.categories():
+                    matches = self.rank(image_features, cat.items, top_count=cat.topn)
                     for match, score in matches:
                         if shared.opts.interrogate_return_ranks:
                             res += f", ({match}:{score/100:.3f})"
...
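The loop now expects categories() to yield objects with .topn and .items attributes instead of bare tuples. The real definition sits in a collapsed part of interrogate.py; a namedtuple with roughly this shape would satisfy the call sites (the field order is an assumption, not confirmed by this hunk):

from collections import namedtuple

Category = namedtuple("Category", ["name", "topn", "items"])  # assumed shape

cat = Category(name="artists", topn=3, items=["by artist a", "by artist b"])
print(cat.topn, len(cat.items))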
 import torch
 import platform
-from modules import paths
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
...
-import glob
 import os
 import shutil
 import importlib

@@ -40,7 +39,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
         if os.path.islink(full_path) and not os.path.exists(full_path):
             print(f"Skipping broken symlink: {full_path}")
             continue
-        if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+        if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
             continue
         if full_path not in output:
             output.append(full_path)

@@ -108,12 +107,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
                 print(f"Moving {file} from {src_path} to {dest_path}.")
                 try:
                     shutil.move(fullpath, dest_path)
-                except:
+                except Exception:
                     pass
         if len(os.listdir(src_path)) == 0:
             print(f"Removing empty folder: {src_path}")
             shutil.rmtree(src_path, True)
-    except:
+    except Exception:
         pass

@@ -127,7 +126,7 @@ def load_upscalers():
         full_model = f"modules.{model_name}_model"
         try:
             importlib.import_module(full_model)
-        except:
+        except Exception:
             pass

     datas = []
...
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler  # noqa: F401

@@ -54,7 +54,8 @@ class UniPCSampler(object):
         if conditioning is not None:
             if isinstance(conditioning, dict):
                 ctmp = conditioning[list(conditioning.keys())[0]]
-                while isinstance(ctmp, list): ctmp = ctmp[0]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
                 cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
...
-from pyngrok import ngrok, conf, exception
+import ngrok

-def connect(token, port, region):
+# Connect to ngrok for ingress
+def connect(token, port, options):
     account = None
     if token is None:
         token = 'None'

@@ -10,28 +11,19 @@ def connect(token, port, region):
         token, username, password = token.split(':', 2)
         account = f"{username}:{password}"

-    config = conf.PyngrokConfig(
-        auth_token=token, region=region
-    )
-
-    # Guard for existing tunnels
-    existing = ngrok.get_tunnels(pyngrok_config=config)
-    if existing:
-        for established in existing:
-            # Extra configuration in the case that the user is also using ngrok for other tunnels
-            if established.config['addr'][-4:] == str(port):
-                public_url = existing[0].public_url
-                print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
-                      'You can use this link after the launch is complete.')
-                return
+    # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
+    if not options.get('authtoken_from_env'):
+        options['authtoken'] = token
+    if account:
+        options['basic_auth'] = account
+    if not options.get('session_metadata'):
+        options['session_metadata'] = 'stable-diffusion-webui'

     try:
-        if account is None:
-            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
-        else:
-            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
-    except exception.PyngrokNgrokError:
-        print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+        public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
+    except Exception as e:
+        print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
               f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
     else:
         print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
...
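options here is the dict parsed from the new --ngrok-options flag earlier in this diff. A hedged example of driving the rewritten connect() (the token and credentials are placeholders; the option keys come from the flag's help text):

options = {
    "authtoken_from_env": False,
    "basic_auth": "user:password",
    "oauth_provider": "google",
    "oauth_allow_emails": "user@asdf.com",
}
connect("<your-ngrok-token>", 7860, options)  # 7860 is the webui default port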
 import os
 import sys

-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir  # noqa: F401

-import modules.safe
+import modules.safe  # noqa: F401

 # data_path = cmd_opts_pre.data
...
-import re
 import os

 import torch
...