Commit 993dd9a8 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'dev' into patch-1

parents ff6acd35 d7d6e8cf
extensions
extensions-disabled
repositories
venv
\ No newline at end of file
/* global module */
module.exports = {
env: {
browser: true,
es2021: true,
},
extends: "eslint:recommended",
parserOptions: {
ecmaVersion: "latest",
},
rules: {
"arrow-spacing": "error",
"block-spacing": "error",
"brace-style": "error",
"comma-dangle": ["error", "only-multiline"],
"comma-spacing": "error",
"comma-style": ["error", "last"],
"curly": ["error", "multi-line", "consistent"],
"eol-last": "error",
"func-call-spacing": "error",
"function-call-argument-newline": ["error", "consistent"],
"function-paren-newline": ["error", "consistent"],
"indent": ["error", 4],
"key-spacing": "error",
"keyword-spacing": "error",
"linebreak-style": ["error", "unix"],
"no-extra-semi": "error",
"no-mixed-spaces-and-tabs": "error",
"no-multi-spaces": "error",
"no-redeclare": ["error", {builtinGlobals: false}],
"no-trailing-spaces": "error",
"no-unused-vars": "off",
"no-whitespace-before-property": "error",
"object-curly-newline": ["error", {consistent: true, multiline: true}],
"object-curly-spacing": ["error", "never"],
"operator-linebreak": ["error", "after"],
"quote-props": ["error", "consistent-as-needed"],
"semi": ["error", "always"],
"semi-spacing": "error",
"semi-style": ["error", "last"],
"space-before-blocks": "error",
"space-before-function-paren": ["error", "never"],
"space-in-parens": ["error", "never"],
"space-infix-ops": "error",
"space-unary-ops": "error",
"switch-colon-spacing": "error",
"template-curly-spacing": ["error", "never"],
"unicode-bom": "error",
},
globals: {
//script.js
gradioApp: "readonly",
executeCallbacks: "readonly",
onAfterUiUpdate: "readonly",
onOptionsChanged: "readonly",
onUiLoaded: "readonly",
onUiUpdate: "readonly",
uiCurrentTab: "writable",
uiElementInSight: "readonly",
uiElementIsVisible: "readonly",
//ui.js
opts: "writable",
all_gallery_buttons: "readonly",
selected_gallery_button: "readonly",
selected_gallery_index: "readonly",
switch_to_txt2img: "readonly",
switch_to_img2img_tab: "readonly",
switch_to_img2img: "readonly",
switch_to_sketch: "readonly",
switch_to_inpaint: "readonly",
switch_to_inpaint_sketch: "readonly",
switch_to_extras: "readonly",
get_tab_index: "readonly",
create_submit_args: "readonly",
restart_reload: "readonly",
updateInput: "readonly",
//extraNetworks.js
requestGet: "readonly",
popup: "readonly",
// from python
localization: "readonly",
// progrssbar.js
randomId: "readonly",
requestProgress: "readonly",
// imageviewer.js
modalPrevImage: "readonly",
modalNextImage: "readonly",
// token-counters.js
setupTokenCounters: "readonly",
}
};
# Apply ESlint
9c54b78d9dde5601e916f308d9a9d6953ec39430
\ No newline at end of file
......@@ -43,10 +43,19 @@ body:
- type: input
id: commit
attributes:
label: Commit where the problem happens
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
label: Version or Commit where the problem happens
description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
validations:
required: true
- type: dropdown
id: py-version
attributes:
label: What Python version are you running on ?
multiple: false
options:
- Python 3.10.x
- Python 3.11.x (above, no supported yet)
- Python 3.9.x (below, no recommended)
- type: dropdown
id: platforms
attributes:
......@@ -59,6 +68,35 @@ body:
- iOS
- Android
- Other/Cloud
- type: dropdown
id: device
attributes:
label: What device are you running WebUI on?
multiple: true
options:
- Nvidia GPUs (RTX 20 above)
- Nvidia GPUs (GTX 16 below)
- AMD GPUs (RX 6000 above)
- AMD GPUs (RX 5000 below)
- CPU
- Other GPUs
- type: dropdown
id: cross_attention_opt
attributes:
label: Cross attention optimization
description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization
multiple: false
options:
- Automatic
- xformers
- sdp-no-mem
- sdp
- Doggettx
- V1
- InvokeAI
- "None "
validations:
required: true
- type: dropdown
id: browsers
attributes:
......
# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
## Description
If you have a large change, pay special attention to this paragraph:
* a simple description of what you're trying to accomplish
* a summary of changes in code
* which issues it fixes, if any
> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
## Screenshots/videos:
Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
**Describe what this pull request is trying to achieve.**
## Checklist:
A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
**Additional notes and description of your changes**
More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
**Environment this was tested in**
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
- OS: [e.g. Windows, Linux]
- Browser: [e.g. chrome, safari]
- Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
**Screenshots or videos of your changes**
If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
This is **required** for anything that touches the user interface.
\ No newline at end of file
- [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
- [ ] I have performed a self-review of my own code
- [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style)
- [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests)
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
name: Run Linting/Formatting on Pull Requests
on:
- push
- pull_request
# See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
# if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
# pull_request:
# branches:
# - master
# branches-ignore:
# - development
jobs:
lint:
lint-python:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
- uses: actions/setup-python@v4
with:
python-version: 3.10.6
cache: pip
cache-dependency-path: |
**/requirements*txt
- name: Install PyLint
run: |
python -m pip install --upgrade pip
pip install pylint
# This lets PyLint check to see if it can resolve imports
- name: Install dependencies
run: |
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
python launch.py
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
python-version: 3.11
# NB: there's no cache: pip here since we're not installing anything
# from the requirements.txt file(s) in the repository; it's faster
# not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies.
- name: Install Ruff
run: pip install ruff==0.0.272
- name: Run Ruff
run: ruff .
lint-js:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Install Node.js
uses: actions/setup-node@v3
with:
node-version: 18
- run: npm i --ci
- run: npm run lint
......@@ -17,13 +17,54 @@ jobs:
cache: pip
cache-dependency-path: |
**/requirements*txt
launch.py
- name: Install test dependencies
run: pip install wait-for-it -r requirements-test.txt
env:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
PIP_PROGRESS_BAR: "off"
- name: Setup environment
run: python launch.py --skip-torch-cuda-test --exit
env:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
PIP_PROGRESS_BAR: "off"
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
PYTHONUNBUFFERED: "1"
- name: Start test server
run: >
python -m coverage run
--data-file=.coverage.server
launch.py
--skip-prepare-environment
--skip-torch-cuda-test
--test-server
--no-half
--disable-opt-split-attention
--use-cpu all
--api-server-stop
2>&1 | tee output.txt &
- name: Run tests
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
run: |
wait-for-it --service 127.0.0.1:7860 -t 600
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server
if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
- name: Show coverage
run: |
python -m coverage combine .coverage*
python -m coverage report -i
python -m coverage html -i
- name: Upload main app output
uses: actions/upload-artifact@v3
if: always()
with:
name: output
path: output.txt
- name: Upload coverage HTML
uses: actions/upload-artifact@v3
if: always()
with:
name: stdout-stderr
path: |
test/stdout.txt
test/stderr.txt
name: htmlcov
path: htmlcov
......@@ -34,3 +34,6 @@ notification.mp3
/test/stderr.txt
/cache.json*
/config_states/
/node_modules
/package-lock.json
/.coverage*
This diff is collapsed.
......@@ -15,7 +15,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Attention, specify parts of text that the model should pay more attention to
- a man in a `((tuxedo))` - will pay more attention to tuxedo
- a man in a `(tuxedo:1.21)` - alternative syntax
- select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
- select text and press `Ctrl+Up` or `Ctrl+Down` (or `Command+Up` or `Command+Down` if you're on a MacOS) to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
......@@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Installation on Windows 10/11 with NVidia-GPUs using release package
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
2. Run `update.bat`.
3. Run `run.bat`.
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
......@@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
......@@ -88,7 +88,7 @@ class LDSR:
x_t = None
logs = None
for n in range(n_runs):
for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
......@@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps)
eta = 1.0
down_sample_method = 'Lanczos'
gc.collect()
if torch.cuda.is_available:
......@@ -131,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
......@@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path):
example = dict()
example = {}
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
......@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict()
log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
......@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
except Exception:
pass
log["sample"] = x_sample
......
import os
import sys
import traceback
from basicsr.utils.download_util import load_file_from_url
from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
......@@ -44,22 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
try:
return LDSR(model, yaml)
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
except Exception:
print("Error importing LDSR:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
return LDSR(model, yaml)
def do_upscale(self, img, path):
ldsr = self.load_model(path)
if ldsr is None:
print("NO LDSR!")
try:
ldsr = self.load_model(path)
except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
......
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
......@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
......@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
......@@ -76,18 +81,19 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
if missing:
print(f"Missing Keys: {missing}")
if unexpected:
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
......@@ -165,7 +171,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
......@@ -232,7 +238,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
......@@ -249,7 +255,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
......@@ -264,7 +271,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
......@@ -282,5 +289,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant)
return dec
setattr(ldm.models.autoencoder, "VQModel", VQModel)
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface
......@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
......@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
......@@ -182,22 +182,22 @@ class DDPMV1(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
if missing:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
if unexpected:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
......@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
......@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x
# get diffusion row
diffusion_row = list()
diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
......@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
......@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
......@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
......@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
......@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids
assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
......@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
[x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
......@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates
@torch.no_grad()
......@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates:
return img, intermediates
......@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
[x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
......@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None
log = dict()
log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
......@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
......@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
......@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
......@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img
return logs
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
# where the license is as follows:
#
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE./
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits is False, "Only for interface compatible with Gumbel"
assert return_logits is False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
......@@ -9,19 +9,37 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
multipliers = []
for params in params_list:
assert len(params.items) > 0
assert params.items
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
lora.load_loras(names, multipliers)
if shared.opts.lora_add_hashes_to_infotext:
lora_hashes = []
for item in lora.loaded_loras:
shorthash = item.lora_on_disk.shorthash
if not shorthash:
continue
alias = item.mentioned_name
if not alias:
continue
alias = alias.replace(":", "").replace(",", "")
lora_hashes.append(f"{alias}: {shorthash}")
if lora_hashes:
p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
def deactivate(self, p):
pass
import glob
import os
import re
import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts
from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
......@@ -77,9 +76,9 @@ class LoraOnDisk:
self.name = name
self.filename = filename
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
_, ext = os.path.splitext(filename)
if ext.lower() == ".safetensors":
if self.is_safetensors:
try:
self.metadata = sd_models.read_metadata_from_safetensors(filename)
except Exception as e:
......@@ -95,14 +94,43 @@ class LoraOnDisk:
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
self.shorthash = None
self.set_hash(
self.metadata.get('sshs_model_hash') or
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
''
)
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]
if self.shorthash:
available_lora_hash_lookup[self.shorthash] = self
def read_hash(self):
if not self.hash:
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
def get_alias(self):
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
return self.name
else:
return self.alias
class LoraModule:
def __init__(self, name):
def __init__(self, name, lora_on_disk: LoraOnDisk):
self.name = name
self.lora_on_disk = lora_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
self.mentioned_name = None
"""the text that was used to add lora to prompt - can be either name or an alias"""
class LoraUpDownModule:
def __init__(self):
......@@ -127,11 +155,11 @@ def assign_lora_names_to_compvis_modules(sd_model):
sd_model.lora_layer_mapping = lora_layer_mapping
def load_lora(name, filename):
lora = LoraModule(name)
lora.mtime = os.path.getmtime(filename)
def load_lora(name, lora_on_disk):
lora = LoraModule(name, lora_on_disk)
lora.mtime = os.path.getmtime(lora_on_disk.filename)
sd = sd_models.read_state_dict(filename)
sd = sd_models.read_state_dict(lora_on_disk.filename)
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
......@@ -177,7 +205,7 @@ def load_lora(name, filename):
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
......@@ -189,10 +217,10 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
if keys_failed_to_match:
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora
......@@ -207,30 +235,41 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]):
if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
failed_to_load_loras = []
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
try:
lora = load_lora(name, lora_on_disk.filename)
lora = load_lora(name, lora_on_disk)
except Exception as e:
errors.display(e, f"loading Lora {lora_on_disk.filename}")
continue
lora.mentioned_name = name
lora_on_disk.read_hash()
if lora is None:
failed_to_load_loras.append(name)
print(f"Couldn't find Lora with name {name}")
continue
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
if failed_to_load_loras:
sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
def lora_calc_updown(lora, module, target):
with torch.no_grad():
......@@ -314,7 +353,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}')
setattr(self, "lora_current_names", wanted_names)
self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
......@@ -348,8 +387,8 @@ def lora_forward(module, input, original_forward):
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)
self.lora_current_names = ()
self.lora_weights_backup = None
def lora_Linear_forward(self, input):
......@@ -398,17 +437,22 @@ def list_available_loras():
available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
forbidden_lora_aliases.update({"none": 1})
available_lora_hash_lookup.clear()
forbidden_lora_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in sorted(candidates, key=str.lower):
for filename in candidates:
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
entry = LoraOnDisk(name, filename)
try:
entry = LoraOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
continue
available_loras[name] = entry
......@@ -428,7 +472,7 @@ def infotext_pasted(infotext, params):
added = []
for k, v in params.items():
for k in params:
if not k.startswith("AddNet Model "):
continue
......@@ -452,8 +496,10 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
available_loras = {}
available_lora_aliases = {}
available_lora_hash_lookup = {}
forbidden_lora_aliases = {}
loaded_loras = []
......
import re
import torch
import gradio as gr
from fastapi import FastAPI
......@@ -53,8 +55,9 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
......@@ -77,6 +80,37 @@ def api_loras(_: gr.Blocks, app: FastAPI):
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return lora.list_available_loras()
script_callbacks.on_app_started(api_loras)
re_lora = re.compile("<lora:([^:]+):")
def infotext_pasted(infotext, d):
hashes = d.get("Lora hashes")
if not hashes:
return
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
def lora_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
if lora_on_disk is None:
return m.group(0)
return f'<lora:{lora_on_disk.get_alias()}:'
d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
script_callbacks.on_infotext_pasted(infotext_pasted)
......@@ -13,13 +13,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
lora.list_available_loras()
def list_items(self):
for name, lora_on_disk in lora.available_loras.items():
for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
path, ext = os.path.splitext(lora_on_disk.filename)
if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
alias = name
else:
alias = lora_on_disk.alias
alias = lora_on_disk.get_alias()
yield {
"name": name,
......@@ -30,6 +27,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
}
def allowed_directories_for_previews(self):
......
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from scunet_model_arch import SCUNet as net
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url
from modules.shared import opts
from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
......@@ -29,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = []
add_model2 = True
for file in model_paths:
if "http" in file:
if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
......@@ -39,8 +36,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
......@@ -91,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
......@@ -121,20 +118,27 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
if path.startswith("http"):
# TODO: this doesn't use `path` at all?
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
def on_ui_settings():
import gradio as gr
from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
script_callbacks.on_ui_settings(on_ui_settings)
......@@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns:
output: tensor shape [b h w c]
"""
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
if self.type != 'W':
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
......@@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
dims=(1, 2))
if self.type != 'W':
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
return output
def relative_embedding(self):
......@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
\ No newline at end of file
nn.init.constant_(m.weight, 1.0)
import contextlib
import os
import sys
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.shared import opts, state
from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir')
......@@ -20,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
if "http" in model:
if model.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(model)
......@@ -38,42 +35,45 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img, model_file):
model = self.load_model(model_file)
if model is None:
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
except:
except Exception:
pass
return img
def load_model(self, path, scale=4):
if "http" in path:
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
if path.startswith("http"):
filename = modelloader.load_file_from_url(
url=path,
model_dir=self.model_download_path,
file_name=f"{self.model_name.replace(' ', '_')}.pth",
)
else:
filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"):
model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
model = Swin2SR(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
)
params = None
else:
model = net(
model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
......@@ -151,7 +151,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
......
......@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
......@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
......@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
......
......@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
......@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
......@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
return attn_mask
def forward(self, x, x_size):
H, W = x_size
......@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
......@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
return flops
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
......@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
......@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
......@@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
......@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class Upsample_hf(nn.Sequential):
"""Upsample module.
......@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample_hf, self).__init__(*m)
super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
......@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class Swin2SR(nn.Module):
r""" Swin2SR
......@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=7, mlp_ratio=4., qkv_bias=True,
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
......@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
......@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
)
self.layers.append(layer)
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
......@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
......@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
)
self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
......@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
......@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
......@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size)
return x
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
......@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
return x
def forward(self, x):
H, W = x.shape[2:]
......@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
......@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
else:
return x[:, :, :H*self.upscale, :W*self.upscale]
......@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
......@@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
\ No newline at end of file
print(x.shape)
This diff is collapsed.
import gradio as gr
from modules import shared
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
.canvas-tooltip-info {
position: absolute;
top: 10px;
left: 10px;
cursor: help;
background-color: rgba(0, 0, 0, 0.3);
width: 20px;
height: 20px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
flex-direction: column;
z-index: 100;
}
.canvas-tooltip-info::after {
content: '';
display: block;
width: 2px;
height: 7px;
background-color: white;
margin-top: 2px;
}
.canvas-tooltip-info::before {
content: '';
display: block;
width: 2px;
height: 2px;
background-color: white;
}
.canvas-tooltip-content {
display: none;
background-color: #f9f9f9;
color: #333;
border: 1px solid #ddd;
padding: 15px;
position: absolute;
top: 40px;
left: 10px;
width: 250px;
font-size: 16px;
opacity: 0;
border-radius: 8px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
z-index: 100;
}
.canvas-tooltip:hover .canvas-tooltip-content {
display: block;
animation: fadeIn 0.5s;
opacity: 1;
}
@keyframes fadeIn {
from {opacity: 0;}
to {opacity: 1;}
}
import gradio as gr
from modules import scripts, shared, ui_components, ui_settings
from modules.ui_components import FormColumn
class ExtraOptionsSection(scripts.Script):
section = "extra_options"
def __init__(self):
self.comps = None
self.setting_names = None
def title(self):
return "Extra options"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.comps = []
self.setting_names = []
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
self.comps.append(comp)
self.setting_names.append(setting_name)
def get_settings_values():
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
return self.comps
def before_process(self, p, *args):
for name, value in zip(self.setting_names, args):
if name not in p.override_settings:
p.override_settings[name] = value
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
}))
......@@ -4,39 +4,39 @@
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElt) {
var counts = {};
(textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
counts[bracket] = (counts[bracket] || 0) + 1;
});
var errors = [];
var counts = {};
(textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
counts[bracket] = (counts[bracket] || 0) + 1;
});
var errors = [];
function checkPair(open, close, kind) {
if (counts[open] !== counts[close]) {
errors.push(
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
);
function checkPair(open, close, kind) {
if (counts[open] !== counts[close]) {
errors.push(
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
);
}
}
}
checkPair('(', ')', 'round brackets');
checkPair('[', ']', 'square brackets');
checkPair('{', '}', 'curly brackets');
counterElt.title = errors.join('\n');
counterElt.classList.toggle('error', errors.length !== 0);
checkPair('(', ')', 'round brackets');
checkPair('[', ']', 'square brackets');
checkPair('{', '}', 'curly brackets');
counterElt.title = errors.join('\n');
counterElt.classList.toggle('error', errors.length !== 0);
}
function setupBracketChecking(id_prompt, id_counter) {
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter)
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter);
if (textarea && counter) {
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
}
if (textarea && counter) {
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
}
}
onUiLoaded(function () {
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
onUiLoaded(function() {
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
});
<div class='card' style={style} onclick={card_clicked}>
<div class='card' style={style} onclick={card_clicked} {sort_keys}>
{background_image}
{metadata_button}
<div class='actions'>
<div class='additional'>
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
<span style="display:none" class='search_term{serach_only}'>{search_term}</span>
<span style="display:none" class='search_term{search_only}'>{search_term}</span>
</div>
<span class='name'>{name}</span>
<span class='description'>{description}</span>
</div>
</div>
<div>
<a href="/docs">API</a>
<a href="{api_docs}">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
<a href="#" onclick="showProfile('./internal/profile-startup'); return false;">Startup profile</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
<br />
......
......@@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
</pre>
<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
<pre>
MIT License
Copyright (c) 2023 Ollin Boer Bohan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
\ No newline at end of file
let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function(){},0);
function dimensionChange(e, is_width, is_height){
if(is_width){
currentWidth = e.target.value*1.0
}
if(is_height){
currentHeight = e.target.value*1.0
}
var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
if(!inImg2img){
return;
}
var targetElement = null;
var tabIndex = get_tab_index('mode_img2img')
if(tabIndex == 0){ // img2img
targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
} else if(tabIndex == 1){ //Sketch
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
} else if(tabIndex == 2){ // Inpaint
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
} else if(tabIndex == 3){ // Inpaint sketch
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
}
if(targetElement){
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if(!arPreviewRect){
arPreviewRect = document.createElement('div')
arPreviewRect.id = "imageARPreview";
gradioApp().appendChild(arPreviewRect)
}
var viewportOffset = targetElement.getBoundingClientRect();
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
var scaledx = targetElement.naturalWidth*viewportscale
var scaledy = targetElement.naturalHeight*viewportscale
var cleintRectTop = (viewportOffset.top+window.scrollY)
var cleintRectLeft = (viewportOffset.left+window.scrollX)
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
var arscaledx = currentWidth*arscale
var arscaledy = currentHeight*arscale
var arRectTop = cleintRectCentreY-(arscaledy/2)
var arRectLeft = cleintRectCentreX-(arscaledx/2)
var arRectWidth = arscaledx
var arRectHeight = arscaledy
arPreviewRect.style.top = arRectTop+'px';
arPreviewRect.style.left = arRectLeft+'px';
arPreviewRect.style.width = arRectWidth+'px';
arPreviewRect.style.height = arRectHeight+'px';
clearTimeout(arFrameTimeout);
arFrameTimeout = setTimeout(function(){
arPreviewRect.style.display = 'none';
},2000);
arPreviewRect.style.display = 'block';
}
}
onUiUpdate(function(){
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if(arPreviewRect){
arPreviewRect.style.display = 'none';
}
var tabImg2img = gradioApp().querySelector("#tab_img2img");
if (tabImg2img) {
var inImg2img = tabImg2img.style.display == "block";
if(inImg2img){
let inputs = gradioApp().querySelectorAll('input');
inputs.forEach(function(e){
var is_width = e.parentElement.id == "img2img_width"
var is_height = e.parentElement.id == "img2img_height"
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
e.classList.add('scrollwatch')
}
if(is_width){
currentWidth = e.value*1.0
}
if(is_height){
currentHeight = e.value*1.0
}
})
}
}
});
let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function() {}, 0);
function dimensionChange(e, is_width, is_height) {
if (is_width) {
currentWidth = e.target.value * 1.0;
}
if (is_height) {
currentHeight = e.target.value * 1.0;
}
var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
if (!inImg2img) {
return;
}
var targetElement = null;
var tabIndex = get_tab_index('mode_img2img');
if (tabIndex == 0) { // img2img
targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
} else if (tabIndex == 1) { //Sketch
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
} else if (tabIndex == 2) { // Inpaint
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
} else if (tabIndex == 3) { // Inpaint sketch
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
}
if (targetElement) {
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if (!arPreviewRect) {
arPreviewRect = document.createElement('div');
arPreviewRect.id = "imageARPreview";
gradioApp().appendChild(arPreviewRect);
}
var viewportOffset = targetElement.getBoundingClientRect();
var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
var scaledx = targetElement.naturalWidth * viewportscale;
var scaledy = targetElement.naturalHeight * viewportscale;
var cleintRectTop = (viewportOffset.top + window.scrollY);
var cleintRectLeft = (viewportOffset.left + window.scrollX);
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
var arscaledx = currentWidth * arscale;
var arscaledy = currentHeight * arscale;
var arRectTop = cleintRectCentreY - (arscaledy / 2);
var arRectLeft = cleintRectCentreX - (arscaledx / 2);
var arRectWidth = arscaledx;
var arRectHeight = arscaledy;
arPreviewRect.style.top = arRectTop + 'px';
arPreviewRect.style.left = arRectLeft + 'px';
arPreviewRect.style.width = arRectWidth + 'px';
arPreviewRect.style.height = arRectHeight + 'px';
clearTimeout(arFrameTimeout);
arFrameTimeout = setTimeout(function() {
arPreviewRect.style.display = 'none';
}, 2000);
arPreviewRect.style.display = 'block';
}
}
onAfterUiUpdate(function() {
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if (arPreviewRect) {
arPreviewRect.style.display = 'none';
}
var tabImg2img = gradioApp().querySelector("#tab_img2img");
if (tabImg2img) {
var inImg2img = tabImg2img.style.display == "block";
if (inImg2img) {
let inputs = gradioApp().querySelectorAll('input');
inputs.forEach(function(e) {
var is_width = e.parentElement.id == "img2img_width";
var is_height = e.parentElement.id == "img2img_height";
if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
e.addEventListener('input', function(e) {
dimensionChange(e, is_width, is_height);
});
e.classList.add('scrollwatch');
}
if (is_width) {
currentWidth = e.value * 1.0;
}
if (is_height) {
currentHeight = e.value * 1.0;
}
});
}
}
});
This diff is collapsed.
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard
function isValidImageList( files ) {
function isValidImageList(files) {
return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}
function dropReplaceImage( imgWrap, files ) {
if ( ! isValidImageList( files ) ) {
function dropReplaceImage(imgWrap, files) {
if (!isValidImageList(files)) {
return;
}
......@@ -14,46 +14,61 @@ function dropReplaceImage( imgWrap, files ) {
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
const callback = () => {
const fileInput = imgWrap.querySelector('input[type="file"]');
if ( fileInput ) {
if ( files.length === 0 ) {
if (fileInput) {
if (files.length === 0) {
files = new DataTransfer();
files.items.add(tmpFile);
fileInput.files = files.files;
} else {
fileInput.files = files;
}
fileInput.dispatchEvent(new Event('change'));
fileInput.dispatchEvent(new Event('change'));
}
};
if ( imgWrap.closest('#pnginfo_image') ) {
if (imgWrap.closest('#pnginfo_image')) {
// special treatment for PNG Info tab, wait for fetch request to finish
const oldFetch = window.fetch;
window.fetch = async (input, options) => {
window.fetch = async(input, options) => {
const response = await oldFetch(input, options);
if ( 'api/predict/' === input ) {
if ('api/predict/' === input) {
const content = await response.text();
window.fetch = oldFetch;
window.requestAnimationFrame( () => callback() );
window.requestAnimationFrame(() => callback());
return new Response(content, {
status: response.status,
statusText: response.statusText,
headers: response.headers
})
});
}
return response;
};
};
} else {
window.requestAnimationFrame( () => callback() );
window.requestAnimationFrame(() => callback());
}
}
function eventHasFiles(e) {
if (!e.dataTransfer || !e.dataTransfer.files) return false;
if (e.dataTransfer.files.length > 0) return true;
if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true;
return false;
}
function dragDropTargetIsPrompt(target) {
if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true;
if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true;
return false;
}
window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
return;
}
if (!eventHasFiles(e)) return;
var targetImage = target.closest('[data-testid="image"]');
if (!dragDropTargetIsPrompt(target) && !targetImage) return;
e.stopPropagation();
e.preventDefault();
e.dataTransfer.dropEffect = 'copy';
......@@ -61,37 +76,55 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => {
const target = e.composedPath()[0];
if (target.placeholder.indexOf("Prompt") == -1) {
return;
if (!eventHasFiles(e)) return;
if (dragDropTargetIsPrompt(target)) {
e.stopPropagation();
e.preventDefault();
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]');
if (fileInput) {
fileInput.files = files;
fileInput.dispatchEvent(new Event('change'));
}
}
const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap ) {
var targetImage = target.closest('[data-testid="image"]');
if (targetImage) {
e.stopPropagation();
e.preventDefault();
const files = e.dataTransfer.files;
dropReplaceImage(targetImage, files);
return;
}
e.stopPropagation();
e.preventDefault();
const files = e.dataTransfer.files;
dropReplaceImage( imgWrap, files );
});
window.addEventListener('paste', e => {
const files = e.clipboardData.files;
if ( ! isValidImageList( files ) ) {
if (!isValidImageList(files)) {
return;
}
const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
.filter(el => uiElementIsVisible(el));
if ( ! visibleImageFields.length ) {
.filter(el => uiElementIsVisible(el))
.sort((a, b) => uiElementInSight(b) - uiElementInSight(a));
if (!visibleImageFields.length) {
return;
}
const firstFreeImageField = visibleImageFields
.filter(el => el.querySelector('input[type=file]'))?.[0];
dropReplaceImage(
firstFreeImageField ?
firstFreeImageField :
visibleImageFields[visibleImageFields.length - 1]
, files );
firstFreeImageField :
visibleImageFields[visibleImageFields.length - 1]
, files
);
});
function keyupEditAttention(event){
let target = event.originalTarget || event.composedPath()[0];
if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
if (! (event.metaKey || event.ctrlKey)) return;
let isPlus = event.key == "ArrowUp"
let isMinus = event.key == "ArrowDown"
if (!isPlus && !isMinus) return;
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
let text = target.value;
function selectCurrentParenthesisBlock(OPEN, CLOSE){
if (selectionStart !== selectionEnd) return false;
// Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
}
// Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1);
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
}
if (beforeParen === -1 || afterParen === -1) return false;
// Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
function selectCurrentWord(){
if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t";
// seek backward until to find beggining
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
// seek forward to find end
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
selectionEnd++;
}
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
// If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
selectCurrentWord();
}
event.preventDefault();
var closeCharacter = ')'
var delta = opts.keyedit_precision_attention
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>'
delta = opts.keyedit_precision_extra
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
// do not include spaces at the end
while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
selectionEnd -= 1;
}
if(selectionStart == selectionEnd){
return
}
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
selectionStart += 1;
selectionEnd += 1;
}
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0"
if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
selectionStart--;
selectionEnd--;
} else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
}
target.focus();
target.value = text;
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
updateInput(target)
}
addEventListener('keydown', (event) => {
keyupEditAttention(event);
});
function keyupEditAttention(event) {
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
if (!(event.metaKey || event.ctrlKey)) return;
let isPlus = event.key == "ArrowUp";
let isMinus = event.key == "ArrowDown";
if (!isPlus && !isMinus) return;
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
let text = target.value;
function selectCurrentParenthesisBlock(OPEN, CLOSE) {
if (selectionStart !== selectionEnd) return false;
// Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
}
// Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1);
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
}
if (beforeParen === -1 || afterParen === -1) return false;
// Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
function selectCurrentWord() {
if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t";
// seek backward until to find beggining
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
// seek forward to find end
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
selectionEnd++;
}
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
// If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
selectCurrentWord();
}
event.preventDefault();
var closeCharacter = ')';
var delta = opts.keyedit_precision_attention;
if (selectionStart > 0 && text[selectionStart - 1] == '<') {
closeCharacter = '>';
delta = opts.keyedit_precision_extra;
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
// do not include spaces at the end
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
selectionEnd -= 1;
}
if (selectionStart == selectionEnd) {
return;
}
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
selectionStart += 1;
selectionEnd += 1;
}
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12));
if (String(weight).length == 1) weight += ".0";
if (closeCharacter == ')' && weight == 1) {
var endParenPos = text.substring(selectionEnd).indexOf(')');
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
selectionStart--;
selectionEnd--;
} else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
}
target.focus();
target.value = text;
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
updateInput(target);
}
addEventListener('keydown', (event) => {
keyupEditAttention(event);
});
/* alt+left/right moves text in prompt */
function keyupEditOrder(event) {
if (!opts.keyedit_move) return;
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
if (!event.altKey) return;
event.preventDefault()
let isLeft = event.key == "ArrowLeft";
let isRight = event.key == "ArrowRight";
if (!isLeft && !isRight) return;
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
let text = target.value;
let items = text.split(",");
let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
let range = indexEnd - indexStart + 1;
if (isLeft && indexStart > 0) {
items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
target.value = items.join();
target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
target.selectionEnd = items.slice(0, indexEnd).join().length;
} else if (isRight && indexEnd < items.length - 1) {
items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
target.value = items.join();
target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
}
event.preventDefault();
updateInput(target);
}
addEventListener('keydown', (event) => {
keyupEditOrder(event);
});
function extensions_apply(_disabled_list, _update_list, disable_all){
var disable = []
var update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substring(7))
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substring(7))
})
restart_reload()
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
}
function extensions_check(){
var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substring(7))
})
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
var id = randomId()
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
})
return [id, JSON.stringify(disable)]
}
function install_extension_from_index(button, url){
button.disabled = "disabled"
button.value = "Installing..."
var textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
updateInput(textarea)
gradioApp().querySelector('#install_extension_button').click()
}
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
if (config_state_name == "Current") {
return [false, config_state_name, config_restore_type];
}
let restored = "";
if (config_restore_type == "extensions") {
restored = "all saved extension versions";
} else if (config_restore_type == "webui") {
restored = "the webui version";
} else {
restored = "the webui version and all saved extension versions";
}
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
if (confirmed) {
restart_reload();
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
}
return [confirmed, config_state_name, config_restore_type];
}
function extensions_apply(_disabled_list, _update_list, disable_all) {
var disable = [];
var update = [];
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
if (x.name.startsWith("enable_") && !x.checked) {
disable.push(x.name.substring(7));
}
if (x.name.startsWith("update_") && x.checked) {
update.push(x.name.substring(7));
}
});
restart_reload();
return [JSON.stringify(disable), JSON.stringify(update), disable_all];
}
function extensions_check() {
var disable = [];
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
if (x.name.startsWith("enable_") && !x.checked) {
disable.push(x.name.substring(7));
}
});
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
x.innerHTML = "Loading...";
});
var id = randomId();
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
});
return [id, JSON.stringify(disable)];
}
function install_extension_from_index(button, url) {
button.disabled = "disabled";
button.value = "Installing...";
var textarea = gradioApp().querySelector('#extension_to_install textarea');
textarea.value = url;
updateInput(textarea);
gradioApp().querySelector('#install_extension_button').click();
}
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
if (config_state_name == "Current") {
return [false, config_state_name, config_restore_type];
}
let restored = "";
if (config_restore_type == "extensions") {
restored = "all saved extension versions";
} else if (config_restore_type == "webui") {
restored = "the webui version";
} else {
restored = "the webui version and all saved extension versions";
}
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
if (confirmed) {
restart_reload();
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
x.innerHTML = "Loading...";
});
}
return [confirmed, config_state_name, config_restore_type];
}
function toggle_all_extensions(event) {
gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
checkbox_el.checked = event.target.checked;
});
}
function toggle_extension() {
let all_extensions_toggled = true;
for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
if (!checkbox_el.checked) {
all_extensions_toggled = false;
break;
}
}
gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
}
This diff is collapsed.
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){
if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img")
}
if (!img2img_gallery) {
img2img_gallery = attachGalleryListeners("img2img")
}
if (!modal) {
modal = gradioApp().getElementById('lightboxModal')
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
}
onAfterUiUpdate(function() {
if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img");
}
if (!img2img_gallery) {
img2img_gallery = attachGalleryListeners("img2img");
}
if (!modal) {
modal = gradioApp().getElementById('lightboxModal');
modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
}
});
let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
});
mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
}
});
});
function attachGalleryListeners(tab_name) {
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
gradioApp().getElementById(tab_name+"_generation_info_button").click()
});
return gallery;
var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
gradioApp().getElementById(tab_name + "_generation_info_button").click();
}
});
return gallery;
}
This diff is collapsed.
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
function setInactive(elem, inactive){
elem.classList.toggle('inactive', !!inactive)
}
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
function setInactive(elem, inactive) {
elem.classList.toggle('inactive', !!inactive);
}
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
}
......@@ -4,17 +4,16 @@
*/
function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
if ( ! canvases.length ) {
canvases_fixed = false; // TODO: this is unused..?
window.removeEventListener( 'resize', imageMaskResize );
return;
if (!canvases.length) {
window.removeEventListener('resize', imageMaskResize);
return;
}
const wrapper = canvases[0].closest('.touch-none');
const previewImage = wrapper.previousElementSibling;
if ( ! previewImage.complete ) {
previewImage.addEventListener( 'load', imageMaskResize);
if (!previewImage.complete) {
previewImage.addEventListener('load', imageMaskResize);
return;
}
......@@ -24,15 +23,15 @@ function imageMaskResize() {
const nh = previewImage.naturalHeight;
const portrait = nh > nw;
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
wrapper.style.width = `${wW}px`;
wrapper.style.height = `${wH}px`;
wrapper.style.left = `0px`;
wrapper.style.top = `0px`;
canvases.forEach( c => {
canvases.forEach(c => {
c.style.width = c.style.height = '';
c.style.maxWidth = '100%';
c.style.maxHeight = '100%';
......@@ -40,5 +39,5 @@ function imageMaskResize() {
});
}
onUiUpdate(imageMaskResize);
window.addEventListener( 'resize', imageMaskResize);
onAfterUiUpdate(imageMaskResize);
window.addEventListener('resize', imageMaskResize);
window.onload = (function(){
window.addEventListener('drop', e => {
const target = e.composedPath()[0];
if (target.placeholder.indexOf("Prompt") == -1) return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
e.stopPropagation();
e.preventDefault();
const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]');
if ( fileInput ) {
fileInput.files = files;
fileInput.dispatchEvent(new Event('change'));
}
});
});
This diff is collapsed.
let gamepads = [];
window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index;
let isWaiting = false;
setInterval(async () => {
gamepads[index] = setInterval(async() => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0];
......@@ -14,7 +16,7 @@ window.addEventListener('gamepadconnected', (e) => {
}
if (isWaiting) {
await sleepUntil(() => {
const xValue = navigator.getGamepads()[index].axes[0]
const xValue = navigator.getGamepads()[index].axes[0];
if (xValue < 0.3 && xValue > -0.3) {
return true;
}
......@@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
}, 10);
});
window.addEventListener('gamepaddisconnected', (e) => {
clearInterval(gamepads[e.gamepad.index]);
});
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -2,7 +2,6 @@ import os
import re
import torch
from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
......@@ -79,7 +78,7 @@ class DeepDanbooru:
res = []
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from .sampler import UniPCSampler
from .sampler import UniPCSampler # noqa: F401
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment