Commit 182330ae authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'dev' into ngrok-py

parents 0d31f20c 983f2c49
extensions
extensions-disabled
repositories
venv
\ No newline at end of file
module.exports = {
env: {
browser: true,
es2021: true,
},
extends: "eslint:recommended",
parserOptions: {
ecmaVersion: "latest",
},
rules: {
"arrow-spacing": "error",
"block-spacing": "error",
"brace-style": "error",
"comma-dangle": ["error", "only-multiline"],
"comma-spacing": "error",
"comma-style": ["error", "last"],
"curly": ["error", "multi-line", "consistent"],
"eol-last": "error",
"func-call-spacing": "error",
"function-call-argument-newline": ["error", "consistent"],
"function-paren-newline": ["error", "consistent"],
"indent": ["error", 4],
"key-spacing": "error",
"keyword-spacing": "error",
"linebreak-style": ["error", "unix"],
"no-extra-semi": "error",
"no-mixed-spaces-and-tabs": "error",
"no-trailing-spaces": "error",
"no-whitespace-before-property": "error",
"object-curly-newline": ["error", {consistent: true, multiline: true}],
"quote-props": ["error", "consistent-as-needed"],
"semi": ["error", "always"],
"semi-spacing": "error",
"semi-style": ["error", "last"],
"space-before-blocks": "error",
"space-before-function-paren": ["error", "never"],
"space-in-parens": ["error", "never"],
"space-infix-ops": "error",
"space-unary-ops": "error",
"switch-colon-spacing": "error",
"template-curly-spacing": ["error", "never"],
"unicode-bom": "error",
"no-multi-spaces": "error",
"object-curly-spacing": ["error", "never"],
"operator-linebreak": ["error", "after"],
"no-unused-vars": "off",
"no-redeclare": "off",
},
globals: {
// this file
module: "writable",
//script.js
gradioApp: "writable",
onUiLoaded: "writable",
onUiUpdate: "writable",
onOptionsChanged: "writable",
uiCurrentTab: "writable",
uiElementIsVisible: "writable",
executeCallbacks: "writable",
//ui.js
opts: "writable",
all_gallery_buttons: "writable",
selected_gallery_button: "writable",
selected_gallery_index: "writable",
args_to_array: "writable",
switch_to_txt2img: "writable",
switch_to_img2img_tab: "writable",
switch_to_img2img: "writable",
switch_to_sketch: "writable",
switch_to_inpaint: "writable",
switch_to_inpaint_sketch: "writable",
switch_to_extras: "writable",
get_tab_index: "writable",
create_submit_args: "writable",
restart_reload: "writable",
updateInput: "writable",
//extraNetworks.js
requestGet: "writable",
popup: "writable",
// from python
localization: "writable",
// progrssbar.js
randomId: "writable",
requestProgress: "writable",
// imageviewer.js
modalPrevImage: "writable",
modalNextImage: "writable",
}
};
...@@ -47,6 +47,15 @@ body: ...@@ -47,6 +47,15 @@ body:
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
validations: validations:
required: true required: true
- type: dropdown
id: py-version
attributes:
label: What Python version are you running on ?
multiple: false
options:
- Python 3.10.x
- Python 3.11.x (above, no supported yet)
- Python 3.9.x (below, no recommended)
- type: dropdown - type: dropdown
id: platforms id: platforms
attributes: attributes:
...@@ -59,6 +68,18 @@ body: ...@@ -59,6 +68,18 @@ body:
- iOS - iOS
- Android - Android
- Other/Cloud - Other/Cloud
- type: dropdown
id: device
attributes:
label: What device are you running WebUI on?
multiple: true
options:
- Nvidia GPUs (RTX 20 above)
- Nvidia GPUs (GTX 16 below)
- AMD GPUs (RX 6000 above)
- AMD GPUs (RX 5000 below)
- CPU
- Other GPUs
- type: dropdown - type: dropdown
id: browsers id: browsers
attributes: attributes:
......
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
name: Run Linting/Formatting on Pull Requests name: Run Linting/Formatting on Pull Requests
on: on:
- push - push
- pull_request - pull_request
# See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
# if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
# pull_request:
# branches:
# - master
# branches-ignore:
# - development
jobs: jobs:
lint: lint-python:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Set up Python 3.10 - uses: actions/setup-python@v4
uses: actions/setup-python@v4
with: with:
python-version: 3.10.6 python-version: 3.11
cache: pip # NB: there's no cache: pip here since we're not installing anything
cache-dependency-path: | # from the requirements.txt file(s) in the repository; it's faster
**/requirements*txt # not to have GHA download an (at the time of writing) 4 GB cache
- name: Install PyLint # of PyTorch and other dependencies.
run: | - name: Install Ruff
python -m pip install --upgrade pip run: pip install ruff==0.0.265
pip install pylint - name: Run Ruff
# This lets PyLint check to see if it can resolve imports run: ruff .
- name: Install dependencies lint-js:
run: | runs-on: ubuntu-latest
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" steps:
python launch.py - name: Checkout Code
- name: Analysing the code with pylint uses: actions/checkout@v3
run: | - name: Install Node.js
pylint $(git ls-files '*.py') uses: actions/setup-node@v3
with:
node-version: 18
- run: npm i --ci
- run: npm run lint
...@@ -17,8 +17,14 @@ jobs: ...@@ -17,8 +17,14 @@ jobs:
cache: pip cache: pip
cache-dependency-path: | cache-dependency-path: |
**/requirements*txt **/requirements*txt
launch.py
- name: Run tests - name: Run tests
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
env:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
PIP_PROGRESS_BAR: "off"
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
- name: Upload main app stdout-stderr - name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
if: always() if: always()
......
...@@ -34,3 +34,5 @@ notification.mp3 ...@@ -34,3 +34,5 @@ notification.mp3
/test/stderr.txt /test/stderr.txt
/cache.json* /cache.json*
/config_states/ /config_states/
/node_modules
/package-lock.json
\ No newline at end of file
...@@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab): ...@@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services) - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Installation on Windows 10/11 with NVidia-GPUs using release package
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
2. Run `update.bat`.
3. Run `run.bat`.
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
### Automatic Installation on Windows ### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH". 1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win). 2. Install [git](https://git-scm.com/download/win).
...@@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al ...@@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK - Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You) - (You)
...@@ -88,7 +88,7 @@ class LDSR: ...@@ -88,7 +88,7 @@ class LDSR:
x_t = None x_t = None
logs = None logs = None
for n in range(n_runs): for _ in range(n_runs):
if custom_shape is not None: if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
...@@ -110,7 +110,6 @@ class LDSR: ...@@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps) diffusion_steps = int(steps)
eta = 1.0 eta = 1.0
down_sample_method = 'Lanczos'
gc.collect() gc.collect()
if torch.cuda.is_available: if torch.cuda.is_available:
...@@ -131,11 +130,11 @@ class LDSR: ...@@ -131,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else: else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta) logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"] sample = logs["sample"]
...@@ -158,7 +157,7 @@ class LDSR: ...@@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path): def get_cond(selected_path):
example = dict() example = {}
up_f = 4 up_f = 4
c = selected_path.convert('RGB') c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
...@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s ...@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad() @torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict() log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True, return_first_stage_outputs=True,
...@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize ...@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except: except Exception:
pass pass
log["sample"] = x_sample log["sample"] = x_sample
......
...@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url ...@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared, script_callbacks from modules import shared, script_callbacks
import sd_hijack_autoencoder, sd_hijack_ddpm_v1 import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler): class UpscalerLDSR(Upscaler):
......
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch import torch
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from contextlib import contextmanager from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
import ldm.models.autoencoder import ldm.models.autoencoder
from packaging import version
class VQModel(pl.LightningModule): class VQModel(pl.LightningModule):
def __init__(self, def __init__(self,
...@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule): ...@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed, n_embed,
embed_dim, embed_dim,
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=None,
image_key="image", image_key="image",
colorize_nlabels=None, colorize_nlabels=None,
monitor=None, monitor=None,
...@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule): ...@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor self.lr_g_factor = lr_g_factor
...@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule): ...@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None: if context is not None:
print(f"{context}: Restored training weights") print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()): def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"] sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys or []:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print("Deleting key {} from state_dict.".format(k))
del sd[k] del sd[k]
...@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule): ...@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx) log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope(): with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict return log_dict
def _validation_step(self, batch, batch_idx, suffix=""): def _validation_step(self, batch, batch_idx, suffix=""):
...@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule): ...@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict() log = {}
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
x = x.to(self.device) x = x.to(self.device)
if only_inputs: if only_inputs:
...@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule): ...@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
if plot_ema: if plot_ema:
with self.ema_scope(): with self.ema_scope():
xrec_ema, _ = self(x) xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema log["reconstructions_ema"] = xrec_ema
return log return log
...@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule): ...@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel): class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs): def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs) super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def encode(self, x): def encode(self, x):
...@@ -282,5 +288,5 @@ class VQModelInterface(VQModel): ...@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant) dec = self.decoder(quant)
return dec return dec
setattr(ldm.models.autoencoder, "VQModel", VQModel) ldm.models.autoencoder.VQModel = VQModel
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) ldm.models.autoencoder.VQModelInterface = VQModelInterface
...@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule): ...@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear", beta_schedule="linear",
loss_type="l2", loss_type="l2",
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=None,
load_only_unet=False, load_only_unet=False,
monitor="val/loss", monitor="val/loss",
use_ema=True, use_ema=True,
...@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule): ...@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
...@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule): ...@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
if context is not None: if context is not None:
print(f"{context}: Restored training weights") print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu") sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()): if "state_dict" in list(sd.keys()):
sd = sd["state_dict"] sd = sd["state_dict"]
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys or []:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print("Deleting key {} from state_dict.".format(k))
del sd[k] del sd[k]
...@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule): ...@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad() @torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict() log = {}
x = self.get_input(batch, self.first_stage_key) x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N) N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row) n_row = min(x.shape[0], n_row)
...@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule): ...@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x log["inputs"] = x
# get diffusion row # get diffusion row
diffusion_row = list() diffusion_row = []
x_start = x[:n_row] x_start = x[:n_row]
for t in range(self.num_timesteps): for t in range(self.num_timesteps):
...@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1): ...@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None) ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", []) ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs) super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key self.cond_stage_key = cond_stage_key
try: try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except: except Exception:
self.num_downs = 0 self.num_downs = 0
if not scale_by_std: if not scale_by_std:
self.scale_factor = scale_factor self.scale_factor = scale_factor
...@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config) self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False self.clip_denoised = False
self.bbox_tokenizer = None self.bbox_tokenizer = None
self.restarted_from_ckpt = False self.restarted_from_ckpt = False
if ckpt_path is not None: if ckpt_path is not None:
...@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim # 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface): if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i], output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize) force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])] for i in range(z.shape[-1])]
...@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1): ...@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs) return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False): def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict): if isinstance(cond, dict):
...@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"): if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128) ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64) stride = self.split_input_params["stride"] # eg. (64, 64)
...@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond} [x[:batch_size] for x in cond[key]] for key in cond}
else: else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
...@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1: if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial) intermediates.append(x0_partial)
if callback: callback(i) if callback:
if img_callback: img_callback(img, i) callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
...@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1: if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img) intermediates.append(img)
if callback: callback(i) if callback:
if img_callback: img_callback(img, i) callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates: if return_intermediates:
return img, intermediates return img, intermediates
...@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond} [x[:batch_size] for x in cond[key]] for key in cond}
else: else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond, return self.p_sample_loop(cond,
...@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None use_ddim = ddim_steps is not None
log = dict() log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True, return_first_stage_outputs=True,
force_c_encode=True, force_c_encode=True,
...@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows: if plot_diffusion_rows:
# get diffusion row # get diffusion row
diffusion_row = list() diffusion_row = []
z_start = z[:n_row] z_start = z[:n_row]
for t in range(self.num_timesteps): for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1: if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
...@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1): ...@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint: if inpaint:
# make a simple center square # make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3] h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device) mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in # zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
...@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): ...@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class # TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs): def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs): def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs) logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation' key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key] dset = self.trainer.datamodule.datasets[key]
...@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): ...@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img logs['bbox_image'] = cond_img
return logs return logs
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
import glob
import os import os
import re import re
import torch import torch
...@@ -177,7 +176,7 @@ def load_lora(name, filename): ...@@ -177,7 +176,7 @@ def load_lora(name, filename):
else: else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad(): with torch.no_grad():
module.weight.copy_(weight) module.weight.copy_(weight)
...@@ -189,7 +188,7 @@ def load_lora(name, filename): ...@@ -189,7 +188,7 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight": elif lora_key == "lora_down.weight":
lora_module.down = module lora_module.down = module
else: else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0: if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
...@@ -207,7 +206,7 @@ def load_loras(names, multipliers=None): ...@@ -207,7 +206,7 @@ def load_loras(names, multipliers=None):
loaded_loras.clear() loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names] loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]): if any(x is None for x in loras_on_disk):
list_available_loras() list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names] loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
...@@ -314,7 +313,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu ...@@ -314,7 +313,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}') print(f'failed to calculate lora weights for layer {lora_layer_name}')
setattr(self, "lora_current_names", wanted_names) self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward): def lora_forward(module, input, original_forward):
...@@ -348,8 +347,8 @@ def lora_forward(module, input, original_forward): ...@@ -348,8 +347,8 @@ def lora_forward(module, input, original_forward):
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ()) self.lora_current_names = ()
setattr(self, "lora_weights_backup", None) self.lora_weights_backup = None
def lora_Linear_forward(self, input): def lora_Linear_forward(self, input):
...@@ -428,7 +427,7 @@ def infotext_pasted(infotext, params): ...@@ -428,7 +427,7 @@ def infotext_pasted(infotext, params):
added = [] added = []
for k, v in params.items(): for k in params:
if not k.startswith("AddNet Model "): if not k.startswith("AddNet Model "):
continue continue
......
...@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted) ...@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
})) }))
......
...@@ -10,10 +10,9 @@ from tqdm import tqdm ...@@ -10,10 +10,9 @@ from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet as net
from modules.shared import opts from modules.shared import opts
from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
...@@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for k, v in model.named_parameters(): for _, v in model.named_parameters():
v.requires_grad = False v.requires_grad = False
model = model.to(device) model = model.to(device)
return model return model
def on_ui_settings():
import gradio as gr
from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
script_callbacks.on_ui_settings(on_ui_settings)
...@@ -61,7 +61,9 @@ class WMSA(nn.Module): ...@@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns: Returns:
output: tensor shape [b h w c] output: tensor shape [b h w c]
""" """
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) if self.type != 'W':
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1) h_windows = x.size(1)
w_windows = x.size(2) w_windows = x.size(2)
...@@ -85,8 +87,9 @@ class WMSA(nn.Module): ...@@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output) output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), if self.type != 'W':
dims=(1, 2)) output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
return output return output
def relative_embedding(self): def relative_embedding(self):
...@@ -262,4 +265,4 @@ class SCUNet(nn.Module): ...@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm): elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
\ No newline at end of file
import contextlib
import os import os
import numpy as np import numpy as np
...@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url ...@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
...@@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
img = upscale(img, model) img = upscale(img, model)
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
except: except Exception:
pass pass
return img return img
...@@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list: for w_idx in w_idx_list:
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
break break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch) out_patch_mask = torch.ones_like(out_patch)
......
...@@ -644,7 +644,7 @@ class SwinIR(nn.Module): ...@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
""" """
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
...@@ -805,7 +805,7 @@ class SwinIR(nn.Module): ...@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x): def forward(self, x):
H, W = x.shape[2:] H, W = x.shape[2:]
x = self.check_image_size(x) x = self.check_image_size(x)
self.mean = self.mean.type_as(x) self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range x = (x - self.mean) * self.img_range
...@@ -844,7 +844,7 @@ class SwinIR(nn.Module): ...@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9 flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops() flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers): for layer in self.layers:
flops += layer.flops() flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops() flops += self.upsample.flops()
......
...@@ -74,7 +74,7 @@ class WindowAttention(nn.Module): ...@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
""" """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]): pretrained_window_size=(0, 0)):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
...@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module): ...@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None attn_mask = None
self.register_buffer("attn_mask", attn_mask) self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size): def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
H, W = x_size H, W = x_size
...@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module): ...@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask return attn_mask
def forward(self, x, x_size): def forward(self, x, x_size):
H, W = x_size H, W = x_size
...@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module): ...@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else: else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows # merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
...@@ -369,7 +369,7 @@ class PatchMerging(nn.Module): ...@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2 flops += H * W * self.dim // 2
return flops return flops
class BasicLayer(nn.Module): class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage. """ A basic Swin Transformer layer for one stage.
...@@ -447,7 +447,7 @@ class BasicLayer(nn.Module): ...@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0) nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0) nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0) nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
r""" Image to Patch Embedding r""" Image to Patch Embedding
Args: Args:
...@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module): ...@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None: if self.norm is not None:
flops += Ho * Wo * self.embed_dim flops += Ho * Wo * self.embed_dim
return flops return flops
class RSTB(nn.Module): class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB). """Residual Swin Transformer Block (RSTB).
...@@ -531,7 +531,7 @@ class RSTB(nn.Module): ...@@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads, num_heads=num_heads,
window_size=window_size, window_size=window_size,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop=drop, attn_drop=attn_drop,
drop_path=drop_path, drop_path=drop_path,
norm_layer=norm_layer, norm_layer=norm_layer,
...@@ -622,7 +622,7 @@ class Upsample(nn.Sequential): ...@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else: else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m) super(Upsample, self).__init__(*m)
class Upsample_hf(nn.Sequential): class Upsample_hf(nn.Sequential):
"""Upsample module. """Upsample module.
...@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential): ...@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3)) m.append(nn.PixelShuffle(3))
else: else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample_hf, self).__init__(*m) super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential): class UpsampleOneStep(nn.Sequential):
...@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential): ...@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9 flops = H * W * self.num_feat * 3 * 9
return flops return flops
class Swin2SR(nn.Module): class Swin2SR(nn.Module):
r""" Swin2SR r""" Swin2SR
...@@ -698,8 +698,8 @@ class Swin2SR(nn.Module): ...@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
""" """
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
...@@ -764,7 +764,7 @@ class Swin2SR(nn.Module): ...@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer, norm_layer=norm_layer,
...@@ -776,7 +776,7 @@ class Swin2SR(nn.Module): ...@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
) )
self.layers.append(layer) self.layers.append(layer)
if self.upsampler == 'pixelshuffle_hf': if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList() self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers): for i_layer in range(self.num_layers):
...@@ -787,7 +787,7 @@ class Swin2SR(nn.Module): ...@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer, norm_layer=norm_layer,
...@@ -799,7 +799,7 @@ class Swin2SR(nn.Module): ...@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
) )
self.layers_hf.append(layer) self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features) self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction # build the last conv layer in deep feature extraction
...@@ -829,10 +829,10 @@ class Swin2SR(nn.Module): ...@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential( self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1), nn.Conv2d(3, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat) self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_hf': elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
...@@ -846,7 +846,7 @@ class Swin2SR(nn.Module): ...@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect': elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters) # for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
...@@ -905,7 +905,7 @@ class Swin2SR(nn.Module): ...@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size) x = self.patch_unembed(x, x_size)
return x return x
def forward_features_hf(self, x): def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3]) x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x) x = self.patch_embed(x)
...@@ -919,7 +919,7 @@ class Swin2SR(nn.Module): ...@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size) x = self.patch_unembed(x, x_size)
return x return x
def forward(self, x): def forward(self, x):
H, W = x.shape[2:] H, W = x.shape[2:]
...@@ -951,7 +951,7 @@ class Swin2SR(nn.Module): ...@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x) x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before)) x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before) x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf) x_hf = self.conv_before_upsample_hf(x_hf)
...@@ -977,15 +977,15 @@ class Swin2SR(nn.Module): ...@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x) x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res) x = x + self.conv_last(res)
x = x / self.img_range + self.mean x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux": if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux return x[:, :, :H*self.upscale, :W*self.upscale], aux
elif self.upsampler == "pixelshuffle_hf": elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
else: else:
return x[:, :, :H*self.upscale, :W*self.upscale] return x[:, :, :H*self.upscale, :W*self.upscale]
...@@ -994,7 +994,7 @@ class Swin2SR(nn.Module): ...@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9 flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops() flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers): for layer in self.layers:
flops += layer.flops() flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops() flops += self.upsample.flops()
...@@ -1014,4 +1014,4 @@ if __name__ == '__main__': ...@@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width)) x = torch.randn((1, 3, height, width))
x = model(x) x = model(x)
print(x.shape) print(x.shape)
\ No newline at end of file
...@@ -4,39 +4,39 @@ ...@@ -4,39 +4,39 @@
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElt) { function checkBrackets(textArea, counterElt) {
var counts = {}; var counts = {};
(textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => { (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
counts[bracket] = (counts[bracket] || 0) + 1; counts[bracket] = (counts[bracket] || 0) + 1;
}); });
var errors = []; var errors = [];
function checkPair(open, close, kind) { function checkPair(open, close, kind) {
if (counts[open] !== counts[close]) { if (counts[open] !== counts[close]) {
errors.push( errors.push(
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
); );
}
} }
}
checkPair('(', ')', 'round brackets'); checkPair('(', ')', 'round brackets');
checkPair('[', ']', 'square brackets'); checkPair('[', ']', 'square brackets');
checkPair('{', '}', 'curly brackets'); checkPair('{', '}', 'curly brackets');
counterElt.title = errors.join('\n'); counterElt.title = errors.join('\n');
counterElt.classList.toggle('error', errors.length !== 0); counterElt.classList.toggle('error', errors.length !== 0);
} }
function setupBracketChecking(id_prompt, id_counter) { function setupBracketChecking(id_prompt, id_counter) {
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter) var counter = gradioApp().getElementById(id_counter);
if (textarea && counter) { if (textarea && counter) {
textarea.addEventListener("input", () => checkBrackets(textarea, counter)); textarea.addEventListener("input", () => checkBrackets(textarea, counter));
} }
} }
onUiLoaded(function () { onUiLoaded(function() {
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
setupBracketChecking('img2img_prompt', 'img2img_token_counter'); setupBracketChecking('img2img_prompt', 'img2img_token_counter');
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
}); });
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
<ul> <ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a> <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul> </ul>
<span style="display:none" class='search_term{serach_only}'>{search_term}</span> <span style="display:none" class='search_term{search_only}'>{search_term}</span>
</div> </div>
<span class='name'>{name}</span> <span class='name'>{name}</span>
<span class='description'>{description}</span> <span class='description'>{description}</span>
......
...@@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER ...@@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
</pre>
<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
<pre>
MIT License
Copyright (c) 2023 Ollin Boer Bohan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre> </pre>
\ No newline at end of file
let currentWidth = null; let currentWidth = null;
let currentHeight = null; let currentHeight = null;
let arFrameTimeout = setTimeout(function(){},0); let arFrameTimeout = setTimeout(function() {}, 0);
function dimensionChange(e, is_width, is_height){ function dimensionChange(e, is_width, is_height) {
if(is_width){ if (is_width) {
currentWidth = e.target.value*1.0 currentWidth = e.target.value * 1.0;
} }
if(is_height){ if (is_height) {
currentHeight = e.target.value*1.0 currentHeight = e.target.value * 1.0;
} }
var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block"; var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
if(!inImg2img){ if (!inImg2img) {
return; return;
} }
var targetElement = null; var targetElement = null;
var tabIndex = get_tab_index('mode_img2img') var tabIndex = get_tab_index('mode_img2img');
if(tabIndex == 0){ // img2img if (tabIndex == 0) { // img2img
targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
} else if(tabIndex == 1){ //Sketch } else if (tabIndex == 1) { //Sketch
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
} else if(tabIndex == 2){ // Inpaint } else if (tabIndex == 2) { // Inpaint
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
} else if(tabIndex == 3){ // Inpaint sketch } else if (tabIndex == 3) { // Inpaint sketch
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
} }
if(targetElement){ if (targetElement) {
var arPreviewRect = gradioApp().querySelector('#imageARPreview'); var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if(!arPreviewRect){ if (!arPreviewRect) {
arPreviewRect = document.createElement('div') arPreviewRect = document.createElement('div');
arPreviewRect.id = "imageARPreview"; arPreviewRect.id = "imageARPreview";
gradioApp().appendChild(arPreviewRect) gradioApp().appendChild(arPreviewRect);
} }
var viewportOffset = targetElement.getBoundingClientRect(); var viewportOffset = targetElement.getBoundingClientRect();
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
var scaledx = targetElement.naturalWidth*viewportscale var scaledx = targetElement.naturalWidth * viewportscale;
var scaledy = targetElement.naturalHeight*viewportscale var scaledy = targetElement.naturalHeight * viewportscale;
var cleintRectTop = (viewportOffset.top+window.scrollY) var cleintRectTop = (viewportOffset.top + window.scrollY);
var cleintRectLeft = (viewportOffset.left+window.scrollX) var cleintRectLeft = (viewportOffset.left + window.scrollX);
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight ) var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
var arscaledx = currentWidth*arscale var arscaledx = currentWidth * arscale;
var arscaledy = currentHeight*arscale var arscaledy = currentHeight * arscale;
var arRectTop = cleintRectCentreY-(arscaledy/2) var arRectTop = cleintRectCentreY - (arscaledy / 2);
var arRectLeft = cleintRectCentreX-(arscaledx/2) var arRectLeft = cleintRectCentreX - (arscaledx / 2);
var arRectWidth = arscaledx var arRectWidth = arscaledx;
var arRectHeight = arscaledy var arRectHeight = arscaledy;
arPreviewRect.style.top = arRectTop+'px'; arPreviewRect.style.top = arRectTop + 'px';
arPreviewRect.style.left = arRectLeft+'px'; arPreviewRect.style.left = arRectLeft + 'px';
arPreviewRect.style.width = arRectWidth+'px'; arPreviewRect.style.width = arRectWidth + 'px';
arPreviewRect.style.height = arRectHeight+'px'; arPreviewRect.style.height = arRectHeight + 'px';
clearTimeout(arFrameTimeout); clearTimeout(arFrameTimeout);
arFrameTimeout = setTimeout(function(){ arFrameTimeout = setTimeout(function() {
arPreviewRect.style.display = 'none'; arPreviewRect.style.display = 'none';
},2000); }, 2000);
arPreviewRect.style.display = 'block'; arPreviewRect.style.display = 'block';
} }
} }
onUiUpdate(function(){ onUiUpdate(function() {
var arPreviewRect = gradioApp().querySelector('#imageARPreview'); var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if(arPreviewRect){ if (arPreviewRect) {
arPreviewRect.style.display = 'none'; arPreviewRect.style.display = 'none';
} }
var tabImg2img = gradioApp().querySelector("#tab_img2img"); var tabImg2img = gradioApp().querySelector("#tab_img2img");
if (tabImg2img) { if (tabImg2img) {
var inImg2img = tabImg2img.style.display == "block"; var inImg2img = tabImg2img.style.display == "block";
if(inImg2img){ if (inImg2img) {
let inputs = gradioApp().querySelectorAll('input'); let inputs = gradioApp().querySelectorAll('input');
inputs.forEach(function(e){ inputs.forEach(function(e) {
var is_width = e.parentElement.id == "img2img_width" var is_width = e.parentElement.id == "img2img_width";
var is_height = e.parentElement.id == "img2img_height" var is_height = e.parentElement.id == "img2img_height";
if((is_width || is_height) && !e.classList.contains('scrollwatch')){ if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) e.addEventListener('input', function(e) {
e.classList.add('scrollwatch') dimensionChange(e, is_width, is_height);
} });
if(is_width){ e.classList.add('scrollwatch');
currentWidth = e.value*1.0 }
} if (is_width) {
if(is_height){ currentWidth = e.value * 1.0;
currentHeight = e.value*1.0 }
} if (is_height) {
}) currentHeight = e.value * 1.0;
} }
} });
}); }
}
});
This diff is collapsed.
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard // allows drag-dropping files into gradio image elements, and also pasting images from clipboard
function isValidImageList( files ) { function isValidImageList(files) {
return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type); return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
} }
function dropReplaceImage( imgWrap, files ) { function dropReplaceImage(imgWrap, files) {
if ( ! isValidImageList( files ) ) { if (!isValidImageList(files)) {
return; return;
} }
...@@ -14,44 +14,44 @@ function dropReplaceImage( imgWrap, files ) { ...@@ -14,44 +14,44 @@ function dropReplaceImage( imgWrap, files ) {
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
const callback = () => { const callback = () => {
const fileInput = imgWrap.querySelector('input[type="file"]'); const fileInput = imgWrap.querySelector('input[type="file"]');
if ( fileInput ) { if (fileInput) {
if ( files.length === 0 ) { if (files.length === 0) {
files = new DataTransfer(); files = new DataTransfer();
files.items.add(tmpFile); files.items.add(tmpFile);
fileInput.files = files.files; fileInput.files = files.files;
} else { } else {
fileInput.files = files; fileInput.files = files;
} }
fileInput.dispatchEvent(new Event('change')); fileInput.dispatchEvent(new Event('change'));
} }
}; };
if ( imgWrap.closest('#pnginfo_image') ) { if (imgWrap.closest('#pnginfo_image')) {
// special treatment for PNG Info tab, wait for fetch request to finish // special treatment for PNG Info tab, wait for fetch request to finish
const oldFetch = window.fetch; const oldFetch = window.fetch;
window.fetch = async (input, options) => { window.fetch = async(input, options) => {
const response = await oldFetch(input, options); const response = await oldFetch(input, options);
if ( 'api/predict/' === input ) { if ('api/predict/' === input) {
const content = await response.text(); const content = await response.text();
window.fetch = oldFetch; window.fetch = oldFetch;
window.requestAnimationFrame( () => callback() ); window.requestAnimationFrame(() => callback());
return new Response(content, { return new Response(content, {
status: response.status, status: response.status,
statusText: response.statusText, statusText: response.statusText,
headers: response.headers headers: response.headers
}) });
} }
return response; return response;
}; };
} else { } else {
window.requestAnimationFrame( () => callback() ); window.requestAnimationFrame(() => callback());
} }
} }
window.document.addEventListener('dragover', e => { window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
return; return;
} }
e.stopPropagation(); e.stopPropagation();
...@@ -65,33 +65,34 @@ window.document.addEventListener('drop', e => { ...@@ -65,33 +65,34 @@ window.document.addEventListener('drop', e => {
return; return;
} }
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap ) { if (!imgWrap) {
return; return;
} }
e.stopPropagation(); e.stopPropagation();
e.preventDefault(); e.preventDefault();
const files = e.dataTransfer.files; const files = e.dataTransfer.files;
dropReplaceImage( imgWrap, files ); dropReplaceImage(imgWrap, files);
}); });
window.addEventListener('paste', e => { window.addEventListener('paste', e => {
const files = e.clipboardData.files; const files = e.clipboardData.files;
if ( ! isValidImageList( files ) ) { if (!isValidImageList(files)) {
return; return;
} }
const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')] const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
.filter(el => uiElementIsVisible(el)); .filter(el => uiElementIsVisible(el));
if ( ! visibleImageFields.length ) { if (!visibleImageFields.length) {
return; return;
} }
const firstFreeImageField = visibleImageFields const firstFreeImageField = visibleImageFields
.filter(el => el.querySelector('input[type=file]'))?.[0]; .filter(el => el.querySelector('input[type=file]'))?.[0];
dropReplaceImage( dropReplaceImage(
firstFreeImageField ? firstFreeImageField ?
firstFreeImageField : firstFreeImageField :
visibleImageFields[visibleImageFields.length - 1] visibleImageFields[visibleImageFields.length - 1]
, files ); , files
);
}); });
function keyupEditAttention(event){ function keyupEditAttention(event) {
let target = event.originalTarget || event.composedPath()[0]; let target = event.originalTarget || event.composedPath()[0];
if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return; if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
if (! (event.metaKey || event.ctrlKey)) return; if (!(event.metaKey || event.ctrlKey)) return;
let isPlus = event.key == "ArrowUp" let isPlus = event.key == "ArrowUp";
let isMinus = event.key == "ArrowDown" let isMinus = event.key == "ArrowDown";
if (!isPlus && !isMinus) return; if (!isPlus && !isMinus) return;
let selectionStart = target.selectionStart; let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd; let selectionEnd = target.selectionEnd;
let text = target.value; let text = target.value;
function selectCurrentParenthesisBlock(OPEN, CLOSE){ function selectCurrentParenthesisBlock(OPEN, CLOSE) {
if (selectionStart !== selectionEnd) return false; if (selectionStart !== selectionEnd) return false;
// Find opening parenthesis around current cursor // Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart); const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN); let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false; if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE); let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
} }
// Find closing parenthesis around current cursor // Find closing parenthesis around current cursor
const after = text.substring(selectionStart); const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE); let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false; if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN); let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) { while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1); afterParen = after.indexOf(CLOSE, afterParen + 1);
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
} }
if (beforeParen === -1 || afterParen === -1) return false; if (beforeParen === -1 || afterParen === -1) return false;
// Set the selection to the text between the parenthesis // Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":"); const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1; selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon; selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd); target.setSelectionRange(selectionStart, selectionEnd);
return true; return true;
} }
function selectCurrentWord(){ function selectCurrentWord() {
if (selectionStart !== selectionEnd) return false; if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t"; const delimiters = opts.keyedit_delimiters + " \r\n\t";
// seek backward until to find beggining // seek backward until to find beggining
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--; selectionStart--;
} }
// seek forward to find end // seek forward to find end
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) { while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
selectionEnd++; selectionEnd++;
} }
target.setSelectionRange(selectionStart, selectionEnd); target.setSelectionRange(selectionStart, selectionEnd);
return true; return true;
} }
// If the user hasn't selected anything, let's select their current parenthesis block or word // If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
selectCurrentWord(); selectCurrentWord();
} }
event.preventDefault(); event.preventDefault();
var closeCharacter = ')' var closeCharacter = ')';
var delta = opts.keyedit_precision_attention var delta = opts.keyedit_precision_attention;
if (selectionStart > 0 && text[selectionStart - 1] == '<'){ if (selectionStart > 0 && text[selectionStart - 1] == '<') {
closeCharacter = '>' closeCharacter = '>';
delta = opts.keyedit_precision_extra delta = opts.keyedit_precision_extra;
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") { } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
// do not include spaces at the end // do not include spaces at the end
while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){ while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
selectionEnd -= 1; selectionEnd -= 1;
} }
if(selectionStart == selectionEnd){ if (selectionStart == selectionEnd) {
return return;
} }
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
selectionStart += 1; selectionStart += 1;
selectionEnd += 1; selectionEnd += 1;
} }
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return; if (isNaN(weight)) return;
weight += isPlus ? delta : -delta; weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12)); weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0" if (String(weight).length == 1) weight += ".0";
if (closeCharacter == ')' && weight == 1) { if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
selectionStart--; selectionStart--;
selectionEnd--; selectionEnd--;
} else { } else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
} }
target.focus(); target.focus();
target.value = text; target.value = text;
target.selectionStart = selectionStart; target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd; target.selectionEnd = selectionEnd;
updateInput(target) updateInput(target);
} }
addEventListener('keydown', (event) => { addEventListener('keydown', (event) => {
keyupEditAttention(event); keyupEditAttention(event);
}); });
function extensions_apply(_disabled_list, _update_list, disable_all){ function extensions_apply(_disabled_list, _update_list, disable_all) {
var disable = [] var disable = [];
var update = [] var update = [];
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
if(x.name.startsWith("enable_") && ! x.checked) if (x.name.startsWith("enable_") && !x.checked) {
disable.push(x.name.substring(7)) disable.push(x.name.substring(7));
}
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substring(7)) if (x.name.startsWith("update_") && x.checked) {
}) update.push(x.name.substring(7));
}
restart_reload() });
return [JSON.stringify(disable), JSON.stringify(update), disable_all] restart_reload();
}
return [JSON.stringify(disable), JSON.stringify(update), disable_all];
function extensions_check(){ }
var disable = []
function extensions_check() {
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ var disable = [];
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substring(7)) gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
}) if (x.name.startsWith("enable_") && !x.checked) {
disable.push(x.name.substring(7));
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ }
x.innerHTML = "Loading..." });
})
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
x.innerHTML = "Loading...";
var id = randomId() });
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
}) var id = randomId();
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
return [id, JSON.stringify(disable)]
} });
function install_extension_from_index(button, url){ return [id, JSON.stringify(disable)];
button.disabled = "disabled" }
button.value = "Installing..."
function install_extension_from_index(button, url) {
var textarea = gradioApp().querySelector('#extension_to_install textarea') button.disabled = "disabled";
textarea.value = url button.value = "Installing...";
updateInput(textarea)
var textarea = gradioApp().querySelector('#extension_to_install textarea');
gradioApp().querySelector('#install_extension_button').click() textarea.value = url;
} updateInput(textarea);
function config_state_confirm_restore(_, config_state_name, config_restore_type) { gradioApp().querySelector('#install_extension_button').click();
if (config_state_name == "Current") { }
return [false, config_state_name, config_restore_type];
} function config_state_confirm_restore(_, config_state_name, config_restore_type) {
let restored = ""; if (config_state_name == "Current") {
if (config_restore_type == "extensions") { return [false, config_state_name, config_restore_type];
restored = "all saved extension versions"; }
} else if (config_restore_type == "webui") { let restored = "";
restored = "the webui version"; if (config_restore_type == "extensions") {
} else { restored = "all saved extension versions";
restored = "the webui version and all saved extension versions"; } else if (config_restore_type == "webui") {
} restored = "the webui version";
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + "."); } else {
if (confirmed) { restored = "the webui version and all saved extension versions";
restart_reload(); }
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
x.innerHTML = "Loading..." if (confirmed) {
}) restart_reload();
} gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
return [confirmed, config_state_name, config_restore_type]; x.innerHTML = "Loading...";
} });
}
return [confirmed, config_state_name, config_restore_type];
}
This diff is collapsed.
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined; let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){ onUiUpdate(function() {
if (!txt2img_gallery) { if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img") txt2img_gallery = attachGalleryListeners("txt2img");
} }
if (!img2img_gallery) { if (!img2img_gallery) {
img2img_gallery = attachGalleryListeners("img2img") img2img_gallery = attachGalleryListeners("img2img");
} }
if (!modal) { if (!modal) {
modal = gradioApp().getElementById('lightboxModal') modal = gradioApp().getElementById('lightboxModal');
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] }); modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
} }
}); });
let modalObserver = new MutationObserver(function(mutations) { let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) { mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
gradioApp().getElementById(selectedTab+"_generation_info_button")?.click() gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
}); }
});
}); });
function attachGalleryListeners(tab_name) { function attachGalleryListeners(tab_name) {
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery') var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => { gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
gradioApp().getElementById(tab_name+"_generation_info_button").click() gradioApp().getElementById(tab_name + "_generation_info_button").click();
}); }
return gallery; });
return gallery;
} }
// mouseover tooltips for various UI elements // mouseover tooltips for various UI elements
titles = { var titles = {
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results", "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
"Sampling method": "Which algorithm to use to produce the image", "Sampling method": "Which algorithm to use to produce the image",
"GFPGAN": "Restore low quality faces using GFPGAN neural network", "GFPGAN": "Restore low quality faces using GFPGAN neural network",
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help", "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting", "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
"UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models", "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution", "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)", "\u{1F4D0}": "Auto detect size from img2img",
"Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)", "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
"Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results", "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
...@@ -40,7 +41,7 @@ titles = { ...@@ -40,7 +41,7 @@ titles = {
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image", "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
"Skip": "Stop processing current image and continue processing.", "Skip": "Stop processing current image and continue processing.",
"Interrupt": "Stop processing images and return any results accumulated so far.", "Interrupt": "Stop processing images and return any results accumulated so far.",
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
...@@ -66,8 +67,8 @@ titles = { ...@@ -66,8 +67,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.", "Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.", "Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.", "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
...@@ -96,7 +97,7 @@ titles = { ...@@ -96,7 +97,7 @@ titles = {
"Add difference": "Result = A + (B - C) * M", "Add difference": "Result = A + (B - C) * M",
"No interpolation": "Result = A", "No interpolation": "Result = A",
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
...@@ -113,38 +114,55 @@ titles = { ...@@ -113,38 +114,55 @@ titles = {
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.", "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
} };
function updateTooltipForSpan(span) {
if (span.title) return; // already has a title
onUiUpdate(function(){ let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
if (span.title) return; // already has a title
let tooltip = localization[titles[span.textContent]] || titles[span.textContent]; if (!tooltip) {
tooltip = localization[titles[span.value]] || titles[span.value];
}
if(!tooltip){ if (!tooltip) {
tooltip = localization[titles[span.value]] || titles[span.value]; for (const c of span.classList) {
} if (c in titles) {
tooltip = localization[titles[c]] || titles[c];
break;
}
}
}
if(!tooltip){ if (tooltip) {
for (const c of span.classList) { span.title = tooltip;
if (c in titles) { }
tooltip = localization[titles[c]] || titles[c]; }
break;
}
}
}
if(tooltip){ function updateTooltipForSelect(select) {
span.title = tooltip; if (select.onchange != null) return;
}
})
gradioApp().querySelectorAll('select').forEach(function(select){ select.onchange = function() {
if (select.onchange != null) return; select.title = localization[titles[select.value]] || titles[select.value] || "";
};
}
select.onchange = function(){ var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
select.title = localization[titles[select.value]] || titles[select.value] || "";
} onUiUpdate(function(m) {
}) m.forEach(function(record) {
}) record.addedNodes.forEach(function(node) {
if (observedTooltipElements[node.tagName]) {
updateTooltipForSpan(node);
}
if (node.tagName == "SELECT") {
updateTooltipForSelect(node);
}
if (node.querySelectorAll) {
node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
node.querySelectorAll('select').forEach(updateTooltipForSelect);
}
});
});
});
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
function setInactive(elem, inactive){ function setInactive(elem, inactive) {
elem.classList.toggle('inactive', !!inactive) elem.classList.toggle('inactive', !!inactive);
} }
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0) setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0) setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0) setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y] return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
} }
...@@ -4,17 +4,16 @@ ...@@ -4,17 +4,16 @@
*/ */
function imageMaskResize() { function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
if ( ! canvases.length ) { if (!canvases.length) {
canvases_fixed = false; // TODO: this is unused..? window.removeEventListener('resize', imageMaskResize);
window.removeEventListener( 'resize', imageMaskResize ); return;
return;
} }
const wrapper = canvases[0].closest('.touch-none'); const wrapper = canvases[0].closest('.touch-none');
const previewImage = wrapper.previousElementSibling; const previewImage = wrapper.previousElementSibling;
if ( ! previewImage.complete ) { if (!previewImage.complete) {
previewImage.addEventListener( 'load', imageMaskResize); previewImage.addEventListener('load', imageMaskResize);
return; return;
} }
...@@ -24,15 +23,15 @@ function imageMaskResize() { ...@@ -24,15 +23,15 @@ function imageMaskResize() {
const nh = previewImage.naturalHeight; const nh = previewImage.naturalHeight;
const portrait = nh > nw; const portrait = nh > nw;
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh); const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
wrapper.style.width = `${wW}px`; wrapper.style.width = `${wW}px`;
wrapper.style.height = `${wH}px`; wrapper.style.height = `${wH}px`;
wrapper.style.left = `0px`; wrapper.style.left = `0px`;
wrapper.style.top = `0px`; wrapper.style.top = `0px`;
canvases.forEach( c => { canvases.forEach(c => {
c.style.width = c.style.height = ''; c.style.width = c.style.height = '';
c.style.maxWidth = '100%'; c.style.maxWidth = '100%';
c.style.maxHeight = '100%'; c.style.maxHeight = '100%';
...@@ -41,4 +40,4 @@ function imageMaskResize() { ...@@ -41,4 +40,4 @@ function imageMaskResize() {
} }
onUiUpdate(imageMaskResize); onUiUpdate(imageMaskResize);
window.addEventListener( 'resize', imageMaskResize); window.addEventListener('resize', imageMaskResize);
window.onload = (function(){ window.onload = (function() {
window.addEventListener('drop', e => { window.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
if (target.placeholder.indexOf("Prompt") == -1) return; if (target.placeholder.indexOf("Prompt") == -1) return;
...@@ -10,7 +10,7 @@ window.onload = (function(){ ...@@ -10,7 +10,7 @@ window.onload = (function(){
const imgParent = gradioApp().getElementById(prompt_target); const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files; const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]'); const fileInput = imgParent.querySelector('input[type="file"]');
if ( fileInput ) { if (fileInput) {
fileInput.files = files; fileInput.files = files;
fileInput.dispatchEvent(new Event('change')); fileInput.dispatchEvent(new Event('change'));
} }
......
This diff is collapsed.
window.addEventListener('gamepadconnected', (e) => { window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index; const index = e.gamepad.index;
let isWaiting = false; let isWaiting = false;
setInterval(async () => { setInterval(async() => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return; if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index]; const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0]; const xValue = gamepad.axes[0];
...@@ -14,7 +14,7 @@ window.addEventListener('gamepadconnected', (e) => { ...@@ -14,7 +14,7 @@ window.addEventListener('gamepadconnected', (e) => {
} }
if (isWaiting) { if (isWaiting) {
await sleepUntil(() => { await sleepUntil(() => {
const xValue = navigator.getGamepads()[index].axes[0] const xValue = navigator.getGamepads()[index].axes[0];
if (xValue < 0.3 && xValue > -0.3) { if (xValue < 0.3 && xValue > -0.3) {
return true; return true;
} }
......
// localization = {} -- the dict with translations is created by the backend // localization = {} -- the dict with translations is created by the backend
ignore_ids_for_localization={ var ignore_ids_for_localization = {
setting_sd_hypernetwork: 'OPTION', setting_sd_hypernetwork: 'OPTION',
setting_sd_model_checkpoint: 'OPTION', setting_sd_model_checkpoint: 'OPTION',
setting_realesrgan_enabled_models: 'OPTION', modelmerger_primary_model_name: 'OPTION',
modelmerger_primary_model_name: 'OPTION', modelmerger_secondary_model_name: 'OPTION',
modelmerger_secondary_model_name: 'OPTION', modelmerger_tertiary_model_name: 'OPTION',
modelmerger_tertiary_model_name: 'OPTION', train_embedding: 'OPTION',
train_embedding: 'OPTION', train_hypernetwork: 'OPTION',
train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION',
txt2img_styles: 'OPTION', img2img_styles: 'OPTION',
img2img_styles: 'OPTION', setting_random_artist_categories: 'SPAN',
setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN',
setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN', extras_upscaler_1: 'SPAN',
extras_upscaler_1: 'SPAN', extras_upscaler_2: 'SPAN',
extras_upscaler_2: 'SPAN', };
}
var re_num = /^[.\d]+$/;
re_num = /^[\.\d]+$/ var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
var original_lines = {};
original_lines = {} var translated_lines = {};
translated_lines = {}
function hasLocalization() {
function hasLocalization() { return window.localization && Object.keys(window.localization).length > 0;
return window.localization && Object.keys(window.localization).length > 0; }
}
function textNodesUnder(el) {
function textNodesUnder(el){ var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); while ((n = walk.nextNode())) a.push(n);
while(n=walk.nextNode()) a.push(n); return a;
return a; }
}
function canBeTranslated(node, text) {
function canBeTranslated(node, text){ if (!text) return false;
if(! text) return false; if (!node.parentElement) return false;
if(! node.parentElement) return false;
var parentType = node.parentElement.nodeName;
var parentType = node.parentElement.nodeName if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType == 'OPTION' || parentType == 'SPAN') {
if (parentType=='OPTION' || parentType=='SPAN'){ var pnode = node;
var pnode = node for (var level = 0; level < 4; level++) {
for(var level=0; level<4; level++){ pnode = pnode.parentElement;
pnode = pnode.parentElement if (!pnode) break;
if(! pnode) break;
if (ignore_ids_for_localization[pnode.id] == parentType) return false;
if(ignore_ids_for_localization[pnode.id] == parentType) return false; }
} }
}
if (re_num.test(text)) return false;
if(re_num.test(text)) return false; if (re_emoji.test(text)) return false;
if(re_emoji.test(text)) return false; return true;
return true }
}
function getTranslation(text) {
function getTranslation(text){ if (!text) return undefined;
if(! text) return undefined
if (translated_lines[text] === undefined) {
if(translated_lines[text] === undefined){ original_lines[text] = 1;
original_lines[text] = 1 }
}
var tl = localization[text];
tl = localization[text] if (tl !== undefined) {
if(tl !== undefined){ translated_lines[tl] = 1;
translated_lines[tl] = 1 }
}
return tl;
return tl }
}
function processTextNode(node) {
function processTextNode(node){ var text = node.textContent.trim();
var text = node.textContent.trim()
if (!canBeTranslated(node, text)) return;
if(! canBeTranslated(node, text)) return
var tl = getTranslation(text);
tl = getTranslation(text) if (tl !== undefined) {
if(tl !== undefined){ node.textContent = tl;
node.textContent = tl }
} }
}
function processNode(node) {
function processNode(node){ if (node.nodeType == 3) {
if(node.nodeType == 3){ processTextNode(node);
processTextNode(node) return;
return }
}
if (node.title) {
if(node.title){ let tl = getTranslation(node.title);
tl = getTranslation(node.title) if (tl !== undefined) {
if(tl !== undefined){ node.title = tl;
node.title = tl }
} }
}
if (node.placeholder) {
if(node.placeholder){ let tl = getTranslation(node.placeholder);
tl = getTranslation(node.placeholder) if (tl !== undefined) {
if(tl !== undefined){ node.placeholder = tl;
node.placeholder = tl }
} }
}
textNodesUnder(node).forEach(function(node) {
textNodesUnder(node).forEach(function(node){ processTextNode(node);
processTextNode(node) });
}) }
}
function dumpTranslations() {
function dumpTranslations(){ if (!hasLocalization()) {
if(!hasLocalization()) { // If we don't have any localization,
// If we don't have any localization, // we will not have traversed the app to find
// we will not have traversed the app to find // original_lines, so do that now.
// original_lines, so do that now. processNode(gradioApp());
processNode(gradioApp()); }
} var dumped = {};
var dumped = {} if (localization.rtl) {
if (localization.rtl) { dumped.rtl = true;
dumped.rtl = true; }
}
for (const text in original_lines) {
for (const text in original_lines) { if (dumped[text] !== undefined) continue;
if(dumped[text] !== undefined) continue; dumped[text] = localization[text] || text;
dumped[text] = localization[text] || text; }
}
return dumped;
return dumped; }
}
function download_localization() {
function download_localization() { var text = JSON.stringify(dumpTranslations(), null, 4);
var text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
var element = document.createElement('a'); element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); element.setAttribute('download', "localization.json");
element.setAttribute('download', "localization.json"); element.style.display = 'none';
element.style.display = 'none'; document.body.appendChild(element);
document.body.appendChild(element);
element.click();
element.click();
document.body.removeChild(element);
document.body.removeChild(element); }
}
document.addEventListener("DOMContentLoaded", function() {
document.addEventListener("DOMContentLoaded", function () { if (!hasLocalization()) {
if (!hasLocalization()) { return;
return; }
}
onUiUpdate(function(m) {
onUiUpdate(function (m) { m.forEach(function(mutation) {
m.forEach(function (mutation) { mutation.addedNodes.forEach(function(node) {
mutation.addedNodes.forEach(function (node) { processNode(node);
processNode(node) });
}) });
}); });
})
processNode(gradioApp());
processNode(gradioApp())
if (localization.rtl) { // if the language is from right to left,
if (localization.rtl) { // if the language is from right to left, (new MutationObserver((mutations, observer) => { // wait for the style to load
(new MutationObserver((mutations, observer) => { // wait for the style to load mutations.forEach(mutation => {
mutations.forEach(mutation => { mutation.addedNodes.forEach(node => {
mutation.addedNodes.forEach(node => { if (node.tagName === 'STYLE') {
if (node.tagName === 'STYLE') { observer.disconnect();
observer.disconnect();
for (const x of node.sheet.rules) { // find all rtl media rules
for (const x of node.sheet.rules) { // find all rtl media rules if (Array.from(x.media || []).includes('rtl')) {
if (Array.from(x.media || []).includes('rtl')) { x.media.appendMedium('all'); // enable them
x.media.appendMedium('all'); // enable them }
} }
} }
} });
}) });
}); })).observe(gradioApp(), {childList: true});
})).observe(gradioApp(), { childList: true }); }
} });
})
...@@ -4,14 +4,14 @@ let lastHeadImg = null; ...@@ -4,14 +4,14 @@ let lastHeadImg = null;
let notificationButton = null; let notificationButton = null;
onUiUpdate(function(){ onUiUpdate(function() {
if(notificationButton == null){ if (notificationButton == null) {
notificationButton = gradioApp().getElementById('request_notifications') notificationButton = gradioApp().getElementById('request_notifications');
if(notificationButton != null){ if (notificationButton != null) {
notificationButton.addEventListener('click', () => { notificationButton.addEventListener('click', () => {
void Notification.requestPermission(); void Notification.requestPermission();
},true); }, true);
} }
} }
...@@ -42,7 +42,7 @@ onUiUpdate(function(){ ...@@ -42,7 +42,7 @@ onUiUpdate(function(){
} }
); );
notification.onclick = function(_){ notification.onclick = function(_) {
parent.focus(); parent.focus();
this.close(); this.close();
}; };
......
// code related to showing and updating progressbar shown as the image is being made // code related to showing and updating progressbar shown as the image is being made
function rememberGallerySelection(){ function rememberGallerySelection() {
} }
function getGallerySelectedIndex(){ function getGallerySelectedIndex() {
} }
function request(url, data, handler, errorHandler){ function request(url, data, handler, errorHandler) {
var xhr = new XMLHttpRequest(); var xhr = new XMLHttpRequest();
xhr.open("POST", url, true); xhr.open("POST", url, true);
xhr.setRequestHeader("Content-Type", "application/json"); xhr.setRequestHeader("Content-Type", "application/json");
xhr.onreadystatechange = function () { xhr.onreadystatechange = function() {
if (xhr.readyState === 4) { if (xhr.readyState === 4) {
if (xhr.status === 200) { if (xhr.status === 200) {
try { try {
var js = JSON.parse(xhr.responseText); var js = JSON.parse(xhr.responseText);
handler(js) handler(js);
} catch (error) { } catch (error) {
console.error(error); console.error(error);
errorHandler() errorHandler();
} }
} else{ } else {
errorHandler() errorHandler();
} }
} }
}; };
...@@ -31,147 +31,147 @@ function request(url, data, handler, errorHandler){ ...@@ -31,147 +31,147 @@ function request(url, data, handler, errorHandler){
xhr.send(js); xhr.send(js);
} }
function pad2(x){ function pad2(x) {
return x<10 ? '0'+x : x return x < 10 ? '0' + x : x;
} }
function formatTime(secs){ function formatTime(secs) {
if(secs > 3600){ if (secs > 3600) {
return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60) return pad2(Math.floor(secs / 60 / 60)) + ":" + pad2(Math.floor(secs / 60) % 60) + ":" + pad2(Math.floor(secs) % 60);
} else if(secs > 60){ } else if (secs > 60) {
return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60) return pad2(Math.floor(secs / 60)) + ":" + pad2(Math.floor(secs) % 60);
} else{ } else {
return Math.floor(secs) + "s" return Math.floor(secs) + "s";
} }
} }
function setTitle(progress){ function setTitle(progress) {
var title = 'Stable Diffusion' var title = 'Stable Diffusion';
if(opts.show_progress_in_title && progress){ if (opts.show_progress_in_title && progress) {
title = '[' + progress.trim() + '] ' + title; title = '[' + progress.trim() + '] ' + title;
} }
if(document.title != title){ if (document.title != title) {
document.title = title; document.title = title;
} }
} }
function randomId(){ function randomId() {
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + ")";
} }
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and // starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd. // preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
// calls onProgress every time there is a progress update // calls onProgress every time there is a progress update
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {
var dateStart = new Date() var dateStart = new Date();
var wasEverActive = false var wasEverActive = false;
var parentProgressbar = progressbarContainer.parentNode var parentProgressbar = progressbarContainer.parentNode;
var parentGallery = gallery ? gallery.parentNode : null var parentGallery = gallery ? gallery.parentNode : null;
var divProgress = document.createElement('div') var divProgress = document.createElement('div');
divProgress.className='progressDiv' divProgress.className = 'progressDiv';
divProgress.style.display = opts.show_progressbar ? "block" : "none" divProgress.style.display = opts.show_progressbar ? "block" : "none";
var divInner = document.createElement('div') var divInner = document.createElement('div');
divInner.className='progress' divInner.className = 'progress';
divProgress.appendChild(divInner) divProgress.appendChild(divInner);
parentProgressbar.insertBefore(divProgress, progressbarContainer) parentProgressbar.insertBefore(divProgress, progressbarContainer);
if(parentGallery){ if (parentGallery) {
var livePreview = document.createElement('div') var livePreview = document.createElement('div');
livePreview.className='livePreview' livePreview.className = 'livePreview';
parentGallery.insertBefore(livePreview, gallery) parentGallery.insertBefore(livePreview, gallery);
} }
var removeProgressBar = function(){ var removeProgressBar = function() {
setTitle("") setTitle("");
parentProgressbar.removeChild(divProgress) parentProgressbar.removeChild(divProgress);
if(parentGallery) parentGallery.removeChild(livePreview) if (parentGallery) parentGallery.removeChild(livePreview);
atEnd() atEnd();
} };
var fun = function(id_task, id_live_preview){ var fun = function(id_task, id_live_preview) {
request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
if(res.completed){ if (res.completed) {
removeProgressBar() removeProgressBar();
return return;
} }
var rect = progressbarContainer.getBoundingClientRect() var rect = progressbarContainer.getBoundingClientRect();
if(rect.width){ if (rect.width) {
divProgress.style.width = rect.width + "px"; divProgress.style.width = rect.width + "px";
} }
let progressText = "" let progressText = "";
divInner.style.width = ((res.progress || 0) * 100.0) + '%' divInner.style.width = ((res.progress || 0) * 100.0) + '%';
divInner.style.background = res.progress ? "" : "transparent" divInner.style.background = res.progress ? "" : "transparent";
if(res.progress > 0){ if (res.progress > 0) {
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';
} }
if(res.eta){ if (res.eta) {
progressText += " ETA: " + formatTime(res.eta) progressText += " ETA: " + formatTime(res.eta);
} }
setTitle(progressText) setTitle(progressText);
if(res.textinfo && res.textinfo.indexOf("\n") == -1){ if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
progressText = res.textinfo + " " + progressText progressText = res.textinfo + " " + progressText;
} }
divInner.textContent = progressText divInner.textContent = progressText;
var elapsedFromStart = (new Date() - dateStart) / 1000 var elapsedFromStart = (new Date() - dateStart) / 1000;
if(res.active) wasEverActive = true; if (res.active) wasEverActive = true;
if(! res.active && wasEverActive){ if (!res.active && wasEverActive) {
removeProgressBar() removeProgressBar();
return return;
} }
if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){ if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {
removeProgressBar() removeProgressBar();
return return;
} }
if(res.live_preview && gallery){ if (res.live_preview && gallery) {
var rect = gallery.getBoundingClientRect() rect = gallery.getBoundingClientRect();
if(rect.width){ if (rect.width) {
livePreview.style.width = rect.width + "px" livePreview.style.width = rect.width + "px";
livePreview.style.height = rect.height + "px" livePreview.style.height = rect.height + "px";
} }
var img = new Image(); var img = new Image();
img.onload = function() { img.onload = function() {
livePreview.appendChild(img) livePreview.appendChild(img);
if(livePreview.childElementCount > 2){ if (livePreview.childElementCount > 2) {
livePreview.removeChild(livePreview.firstElementChild) livePreview.removeChild(livePreview.firstElementChild);
} }
} };
img.src = res.live_preview; img.src = res.live_preview;
} }
if(onProgress){ if (onProgress) {
onProgress(res) onProgress(res);
} }
setTimeout(() => { setTimeout(() => {
fun(id_task, res.id_live_preview); fun(id_task, res.id_live_preview);
}, opts.live_preview_refresh_period || 500) }, opts.live_preview_refresh_period || 500);
}, function(){ }, function() {
removeProgressBar() removeProgressBar();
}) });
} };
fun(id_task, 0) fun(id_task, 0);
} }
function start_training_textual_inversion(){ function start_training_textual_inversion() {
gradioApp().querySelector('#ti_error').innerHTML='' gradioApp().querySelector('#ti_error').innerHTML = '';
var id = randomId() var id = randomId();
requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
}) });
var res = args_to_array(arguments) var res = args_to_array(arguments);
res[0] = id res[0] = id;
return res return res;
} }
This diff is collapsed.
// various hints and extra info for the settings tab // various hints and extra info for the settings tab
onUiLoaded(function(){ var settingsHintsSetup = false;
createLink = function(elem_id, text, href){
var a = document.createElement('A') onOptionsChanged(function() {
a.textContent = text if (settingsHintsSetup) return;
a.target = '_blank'; settingsHintsSetup = true;
elem = gradioApp().querySelector('#'+elem_id) gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
elem.insertBefore(a, elem.querySelector('label')) var name = div.id.substr(8);
var commentBefore = opts._comments_before[name];
return a var commentAfter = opts._comments_after[name];
}
if (!commentBefore && !commentAfter) return;
createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory" var span = null;
if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){ else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
requestGet("./internal/quicksettings-hint", {}, function(data){ else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
var table = document.createElement('table') else span = div.querySelector('label span').firstChild;
table.className = 'settings-value-table'
if (!span) return;
data.forEach(function(obj){
var tr = document.createElement('tr') if (commentBefore) {
var td = document.createElement('td') var comment = document.createElement('DIV');
td.textContent = obj.name comment.className = 'settings-comment';
tr.appendChild(td) comment.innerHTML = commentBefore;
span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
var td = document.createElement('td') span.parentElement.insertBefore(comment, span);
td.textContent = obj.label span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
tr.appendChild(td) }
if (commentAfter) {
table.appendChild(tr) comment = document.createElement('DIV');
}) comment.className = 'settings-comment';
comment.innerHTML = commentAfter;
popup(table); span.parentElement.insertBefore(comment, span.nextSibling);
}) span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
}); }
}) });
});
function settingsHintsShowQuicksettings() {
requestGet("./internal/quicksettings-hint", {}, function(data) {
var table = document.createElement('table');
table.className = 'settings-value-table';
data.forEach(function(obj) {
var tr = document.createElement('tr');
var td = document.createElement('td');
td.textContent = obj.name;
tr.appendChild(td);
td = document.createElement('td');
td.textContent = obj.label;
tr.appendChild(td);
table.appendChild(tr);
});
popup(table);
});
}
...@@ -3,25 +3,23 @@ import subprocess ...@@ -3,25 +3,23 @@ import subprocess
import os import os
import sys import sys
import importlib.util import importlib.util
import shlex
import platform import platform
import json import json
from functools import lru_cache
from modules import cmd_args from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir from modules.paths_internal import script_path, extensions_dir
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
args, _ = cmd_args.parser.parse_known_args() args, _ = cmd_args.parser.parse_known_args()
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
stored_git_tag = None
dir_repos = "repositories" dir_repos = "repositories"
# Whether to default to printing command output
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
...@@ -57,65 +55,52 @@ Use --skip-python-version-check to suppress this warning. ...@@ -57,65 +55,52 @@ Use --skip-python-version-check to suppress this warning.
""") """)
@lru_cache()
def commit_hash(): def commit_hash():
global stored_commit_hash
if stored_commit_hash is not None:
return stored_commit_hash
try: try:
stored_commit_hash = run(f"{git} rev-parse HEAD").strip() return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception: except Exception:
stored_commit_hash = "<none>" return "<none>"
return stored_commit_hash
@lru_cache()
def git_tag(): def git_tag():
global stored_git_tag
if stored_git_tag is not None:
return stored_git_tag
try: try:
stored_git_tag = run(f"{git} describe --tags").strip() return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception: except Exception:
stored_git_tag = "<none>" return "<none>"
return stored_git_tag
def run(command, desc=None, errdesc=None, custom_env=None, live=False): def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
if desc is not None: if desc is not None:
print(desc) print(desc)
if live: run_kwargs = {
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env) "args": command,
if result.returncode != 0: "shell": True,
raise RuntimeError(f"""{errdesc or 'Error running command'}. "env": os.environ if custom_env is None else custom_env,
Command: {command} "encoding": 'utf8',
Error code: {result.returncode}""") "errors": 'ignore',
}
return "" if not live:
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) result = subprocess.run(**run_kwargs)
if result.returncode != 0: if result.returncode != 0:
error_bits = [
f"{errdesc or 'Error running command'}.",
f"Command: {command}",
f"Error code: {result.returncode}",
]
if result.stdout:
error_bits.append(f"stdout: {result.stdout}")
if result.stderr:
error_bits.append(f"stderr: {result.stderr}")
raise RuntimeError("\n".join(error_bits))
message = f"""{errdesc or 'Error running command'}. return (result.stdout or "")
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
raise RuntimeError(message)
return result.stdout.decode(encoding="utf8", errors="ignore")
def check_run(command):
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
return result.returncode == 0
def is_installed(package): def is_installed(package):
...@@ -131,11 +116,7 @@ def repo_dir(name): ...@@ -131,11 +116,7 @@ def repo_dir(name):
return os.path.join(script_path, dir_repos, name) return os.path.join(script_path, dir_repos, name)
def run_python(code, desc=None, errdesc=None): def run_pip(command, desc=None, live=default_command_live):
return run(f'"{python}" -c "{code}"', desc, errdesc)
def run_pip(command, desc=None, live=False):
if args.skip_install: if args.skip_install:
return return
...@@ -143,8 +124,9 @@ def run_pip(command, desc=None, live=False): ...@@ -143,8 +124,9 @@ def run_pip(command, desc=None, live=False):
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def check_run_python(code): def check_run_python(code: str) -> bool:
return check_run(f'"{python}" -c "{code}"') result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
return result.returncode == 0
def git_clone(url, dir, name, commithash=None): def git_clone(url, dir, name, commithash=None):
...@@ -237,13 +219,14 @@ def run_extensions_installers(settings_file): ...@@ -237,13 +219,14 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118") torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
...@@ -270,8 +253,11 @@ def prepare_environment(): ...@@ -270,8 +253,11 @@ def prepare_environment():
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
if not args.skip_torch_cuda_test: if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") raise RuntimeError(
'Torch is not able to use GPU; '
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
)
if not is_installed("gfpgan"): if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan") run_pip(f"install {gfpgan_package}", "gfpgan")
...@@ -319,7 +305,7 @@ def prepare_environment(): ...@@ -319,7 +305,7 @@ def prepare_environment():
if args.update_all_extensions: if args.update_all_extensions:
git_pull_recursive(extensions_dir) git_pull_recursive(extensions_dir)
if "--exit" in sys.argv: if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)
......
This diff is collapsed.
...@@ -223,8 +223,9 @@ for key in _options: ...@@ -223,8 +223,9 @@ for key in _options:
if(_options[key].dest != 'help'): if(_options[key].dest != 'help'):
flag = _options[key] flag = _options[key]
_type = str _type = str
if _options[key].default is not None: _type = type(_options[key].default) if _options[key].default is not None:
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) _type = type(_options[key].default)
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags) FlagsModel = create_model("Flags", **flags)
...@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel): ...@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats") ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
class ScriptsList(BaseModel): class ScriptsList(BaseModel):
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
\ No newline at end of file
class ScriptArg(BaseModel):
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
import argparse import argparse
import json import json
import os import os
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -103,4 +103,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra ...@@ -103,4 +103,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
\ No newline at end of file parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math import math
import numpy as np
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, List from typing import Optional
from modules.codeformer.vqgan_arch import * from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5): def calc_mean_std(feat, eps=1e-5):
...@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module): ...@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
tgt_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): query_pos: Optional[Tensor] = None):
# self attention # self attention
tgt2 = self.norm1(tgt) tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos) q = k = self.with_pos_embed(tgt2, query_pos)
...@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module): ...@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder): class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9, def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256, codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'], connect_list=('32', '64', '128', '256'),
fix_modules=['quantize','generator']): fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None: if fix_modules is not None:
...@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder): ...@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
self.feat_emb = nn.Linear(256, self.dim_embd) self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer # transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)]) for _ in range(self.n_layers)])
# logits_predict head # logits_predict head
self.idx_pred_layer = nn.Sequential( self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False)) nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = { self.channels = {
'16': 512, '16': 512,
'32': 256, '32': 256,
...@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder): ...@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
enc_feat_dict = {} enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks): for i, block in enumerate(self.encoder.blocks):
x = block(x) x = block(x)
if i in out_list: if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone() enc_feat_dict[str(x.shape[-1])] = x.clone()
...@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder): ...@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks): for i, block in enumerate(self.generator.blocks):
x = block(x) x = block(x)
if i in fuse_list: # fuse after i-th block if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1]) f_size = str(x.shape[-1])
if w>0: if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x out = x
# logits doesn't need softmax before cross_entropy loss # logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat return out, logits, lq_feat
\ No newline at end of file
This diff is collapsed.
...@@ -33,11 +33,9 @@ def setup_model(dirname): ...@@ -33,11 +33,9 @@ def setup_model(dirname):
try: try:
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer from modules.codeformer.codeformer_arch import CodeFormer
from basicsr.utils.download_util import load_file_from_url from basicsr.utils import img2tensor, tensor2img
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface from facelib.detection.retinaface import retinaface
from modules.shared import cmd_opts
net_class = CodeFormer net_class = CodeFormer
...@@ -96,7 +94,7 @@ def setup_model(dirname): ...@@ -96,7 +94,7 @@ def setup_model(dirname):
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face() self.face_helper.align_warp_face()
for idx, cropped_face in enumerate(self.face_helper.cropped_faces): for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
......
...@@ -14,7 +14,7 @@ from collections import OrderedDict ...@@ -14,7 +14,7 @@ from collections import OrderedDict
import git import git
from modules import shared, extensions from modules import shared, extensions
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir from modules.paths_internal import script_path, config_states_dir
all_config_states = OrderedDict() all_config_states = OrderedDict()
...@@ -35,7 +35,7 @@ def list_config_states(): ...@@ -35,7 +35,7 @@ def list_config_states():
j["filepath"] = path j["filepath"] = path
config_states.append(j) config_states.append(j)
config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)) config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states: for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"])) timestamp = time.asctime(time.gmtime(cs["created_at"]))
...@@ -83,6 +83,8 @@ def get_extension_config(): ...@@ -83,6 +83,8 @@ def get_extension_config():
ext_config = {} ext_config = {}
for ext in extensions.extensions: for ext in extensions.extensions:
ext.read_info_from_repo()
entry = { entry = {
"name": ext.name, "name": ext.name,
"path": ext.path, "path": ext.path,
......
This diff is collapsed.
...@@ -65,7 +65,7 @@ def enable_tf32(): ...@@ -65,7 +65,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]): if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from modules import extra_networks, shared, extra_networks from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment